Example #1
def main():
    global best_prec1, args

    args.gpu = 0
    args.world_size = 1

    if args.distributed:
        args.gpu = args.local_rank % torch.cuda.device_count()
        torch.cuda.set_device(args.gpu)
        torch.distributed.init_process_group(backend='nccl',
                                             init_method='env://')
        args.world_size = torch.distributed.get_world_size()

    args.total_batch_size = args.world_size * args.batch_size

    if not os.path.isdir(args.checkpoint) and args.local_rank == 0:
        mkdir_p(args.checkpoint)

    assert torch.backends.cudnn.enabled, "Amp requires cudnn backend to be enabled."

    # create model
    if args.pretrained:
        print("=> using pre-trained model '{}'".format(args.arch))
        model = models.__dict__[args.arch](pretrained=True)
    else:
        print("=> creating model '{}'".format(args.arch))
        model = models.__dict__[args.arch]()

    model = model.cuda()

    # args.lr = float(args.lr * float(args.batch_size * args.world_size) / 256.)  # default args.lr = 0.1 -> 256
    optimizer = set_optimizer(model)

    # define loss function (criterion) and optimizer
    criterion = nn.CrossEntropyLoss().cuda()


    model, optimizer = amp.initialize(model, optimizer,
                                      opt_level=args.opt_level,
                                      keep_batchnorm_fp32=args.keep_batchnorm_fp32,
                                      loss_scale=args.loss_scale,
                                      verbosity=0)

    model = DDP(model, delay_allreduce=True)


    # optionally resume from a checkpoint
    title = 'ImageNet-' + args.arch
    args.lastepoch = -1
    if args.resume:
        if os.path.isfile(args.resume):
            print("=> loading checkpoint '{}'".format(args.resume))
            checkpoint = torch.load(args.resume, map_location=lambda storage, loc: storage.cuda(args.gpu))
            args.start_epoch = checkpoint['epoch']
            best_prec1 = checkpoint['best_prec1']
            model.load_state_dict(checkpoint['state_dict'])
            optimizer.load_state_dict(checkpoint['optimizer'])
            print("=> loaded checkpoint '{}' (epoch {})"
                  .format(args.resume, checkpoint['epoch']))
            args.lastepoch = checkpoint['epoch']
            if args.local_rank == 0:
                logger = Logger(os.path.join(args.checkpoint, 'log.txt'), title=title, resume=True)
        else:
            print("=> no checkpoint found at '{}'".format(args.resume))
    else:
        if args.local_rank == 0:
            logger = Logger(os.path.join(args.checkpoint, 'log.txt'), title=title)
            logger.set_names(['Learning Rate', 'Train Loss', 'Valid Loss', 'Train Acc.', 'Valid Acc.', 'Valid Top5.'])

    traindir = os.path.join(args.data, 'train')
    valdir = os.path.join(args.data, 'val')

    if args.arch == "inception_v3":
        crop_size = 299
        val_size = 320 # I chose this value arbitrarily, we can adjust.
    else:
        crop_size = 224
        val_size = 256

    pipe = HybridTrainPipe(batch_size=args.batch_size, num_threads=args.workers, device_id=args.local_rank, data_dir=traindir, crop=crop_size, dali_cpu=args.dali_cpu)
    pipe.build()
    train_loader = DALIClassificationIterator(pipe, size=int(pipe.epoch_size("Reader") / args.world_size))

    pipe = HybridValPipe(batch_size=args.test_batch, num_threads=4, device_id=args.local_rank, data_dir=valdir, crop=crop_size, size=val_size)
    pipe.build()
    val_loader = DALIClassificationIterator(pipe, size=int(pipe.epoch_size("Reader") / args.world_size))

    if args.evaluate:
        validate(val_loader, model, criterion)
        return

    train_loader_len = int(train_loader._size / args.batch_size)
    if args.resume:
        scheduler = CosineAnnealingLR(optimizer, args.epochs, train_loader_len,
                                      eta_min=0., last_epoch=args.lastepoch, warmup=args.warmup)
    else:
        scheduler = CosineAnnealingLR(optimizer,
                                      args.epochs, train_loader_len, eta_min=0., warmup=args.warmup)
    total_time = AverageMeter()
    for epoch in range(args.start_epoch, args.epochs):
        # train for one epoch

        if args.local_rank == 0:
            print('\nEpoch: [%d | %d] LR: %f' % (epoch + 1, args.epochs, optimizer.param_groups[0]['lr']))

        [train_loss, train_acc, avg_train_time] = train(train_loader, model, criterion, optimizer, epoch, scheduler)
        total_time.update(avg_train_time)
        # evaluate on validation set
        [test_loss, prec1, prec5] = validate(val_loader, model, criterion)

        # remember best prec@1 and save checkpoint
        if args.local_rank == 0:
            # append logger file
            logger.append([optimizer.param_groups[0]['lr'], train_loss, test_loss, train_acc, prec1, prec5])

            is_best = prec1 > best_prec1
            best_prec1 = max(prec1, best_prec1)
            save_checkpoint({
                'epoch': epoch + 1,
                'arch': args.arch,
                'state_dict': model.state_dict(),
                'best_prec1': best_prec1,
                'optimizer': optimizer.state_dict(),
            }, is_best, checkpoint=args.checkpoint)
            if epoch == args.epochs - 1:
                print('##Top-1 {0}\n'
                      '##Top-5 {1}\n'
                      '##Perf  {2}'.format(prec1, prec5, args.total_batch_size / total_time.avg))

        # reset DALI iterators
        train_loader.reset()
        val_loader.reset()

    if args.local_rank == 0:
        logger.close()
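The example above follows the usual apex ordering: amp.initialize() wraps the model and optimizer first, and only afterwards is the model wrapped in apex's DistributedDataParallel. A minimal sketch of that ordering, assuming one process per GPU started with torch.distributed.launch (the model and optimizer here are placeholders):

# Typical launch: python -m torch.distributed.launch --nproc_per_node=<num_gpus> train.py
import torch
import torch.nn as nn
from apex import amp
from apex.parallel import DistributedDataParallel as DDP

def init_distributed_amp(local_rank):
    torch.cuda.set_device(local_rank)
    torch.distributed.init_process_group(backend='nccl', init_method='env://')
    model = nn.Linear(512, 1000).cuda()                      # placeholder model
    optimizer = torch.optim.SGD(model.parameters(), lr=0.1)  # placeholder optimizer
    # amp.initialize must run before the DDP wrapper, as in the example above
    model, optimizer = amp.initialize(model, optimizer, opt_level='O1')
    model = DDP(model, delay_allreduce=True)
    return model, optimizer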
Example #2
def main():
    global best_acc
    start_epoch = args.start_epoch  # start from epoch 0 or last checkpoint epoch

    if not os.path.isdir(args.checkpoint) and args.local_rank == 0:
        mkdir_p(args.checkpoint)

    args.distributed = True
    args.gpu = args.local_rank
    torch.cuda.set_device(args.gpu)
    torch.distributed.init_process_group(backend='nccl', init_method='env://')
    args.world_size = torch.distributed.get_world_size()
    print('world_size = ', args.world_size)

    assert torch.backends.cudnn.enabled, "Amp requires cudnn backend to be enabled."

    # create model
    if args.pretrained:
        print("=> using pre-trained model '{}'".format(args.arch))
        model = models.__dict__[args.arch](pretrained=True)
    elif 'resnext' in args.arch:
        model = models.__dict__[args.arch](
            baseWidth=args.base_width,
            cardinality=args.cardinality,
        )
    else:
        print("=> creating model '{}'".format(args.arch))
        model = models.__dict__[args.arch]()

    flops, params = get_model_complexity_info(model, (224, 224),
                                              as_strings=False,
                                              print_per_layer_stat=False)
    print('Flops:  %.3f' % (flops / 1e9))
    print('Params: %.2fM' % (params / 1e6))

    cudnn.benchmark = True
    # define loss function (criterion) and optimizer
    # criterion = nn.CrossEntropyLoss().cuda()
    criterion = SoftCrossEntropyLoss(label_smoothing=0.1).cuda()
    model = model.cuda()

    # scale the base learning rate linearly with the global batch size (reference batch: 256)
    # before creating the optimizer, so set_optimizer picks up the scaled value
    args.lr = float(0.1 * float(args.train_batch * args.world_size) / 256.)
    optimizer = set_optimizer(model)
    # optim.SGD(model.parameters(), lr=args.lr, momentum=args.momentum, weight_decay=args.weight_decay)

    model, optimizer = amp.initialize(
        model,
        optimizer,
        opt_level=args.opt_level,
        keep_batchnorm_fp32=args.keep_batchnorm_fp32,
        loss_scale=args.loss_scale)

    #model = torch.nn.DataParallel(model).cuda()
    #model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[args.local_rank], output_device=args.local_rank)
    model = DDP(model, delay_allreduce=True)

    # Data loading code
    traindir = os.path.join(args.data, 'train')
    valdir = os.path.join(args.data, 'valf')
    #normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])

    data_aug_scale = (0.08, 1.0) if args.modelsize == 'large' else (0.2, 1.0)

    train_dataset = datasets.ImageFolder(
        traindir,
        transforms.Compose([
            transforms.RandomResizedCrop(224, scale=data_aug_scale),
            transforms.RandomHorizontalFlip(),
            # transforms.ToTensor(),
            # normalize,
        ]))
    val_dataset = datasets.ImageFolder(
        valdir,
        transforms.Compose([
            transforms.Resize(256),
            transforms.CenterCrop(224),
            # transforms.ToTensor(),
            # normalize,
        ]))

    train_sampler = torch.utils.data.distributed.DistributedSampler(
        train_dataset)
    val_sampler = torch.utils.data.distributed.DistributedSampler(val_dataset)

    train_loader = torch.utils.data.DataLoader(train_dataset,
                                               batch_size=args.train_batch,
                                               shuffle=False,
                                               num_workers=args.workers,
                                               pin_memory=True,
                                               sampler=train_sampler,
                                               collate_fn=fast_collate)
    val_loader = torch.utils.data.DataLoader(val_dataset,
                                             batch_size=args.test_batch,
                                             shuffle=False,
                                             num_workers=args.workers,
                                             pin_memory=True,
                                             sampler=val_sampler,
                                             collate_fn=fast_collate)

    # Resume
    title = 'ImageNet-' + args.arch
    if args.resume:
        # Load checkpoint.
        print('==> Resuming from checkpoint..', args.resume)
        assert os.path.isfile(
            args.resume), 'Error: no checkpoint directory found!'
        args.checkpoint = os.path.dirname(args.resume)
        checkpoint = torch.load(args.resume)
        best_acc = checkpoint['best_acc']
        start_epoch = checkpoint['epoch']
        # model may have more keys
        t = model.state_dict()
        c = checkpoint['state_dict']
        flag = True
        for k in t:
            if k not in c:
                print('not in loading dict! fill it', k, t[k])
                c[k] = t[k]
                flag = False
        model.load_state_dict(c)
        if flag:
            print('optimizer load old state')
            optimizer.load_state_dict(checkpoint['optimizer'])
        else:
            print('new optimizer !')
        if args.local_rank == 0:
            logger = Logger(os.path.join(args.checkpoint, 'log.txt'),
                            title=title,
                            resume=True)
    else:
        if args.local_rank == 0:
            logger = Logger(os.path.join(args.checkpoint, 'log.txt'),
                            title=title)
            logger.set_names([
                'Learning Rate', 'Train Loss', 'Valid Loss', 'Train Acc.',
                'Valid Acc.'
            ])

    if args.evaluate:
        print('\nEvaluation only')
        test_loss, test_acc = test(val_loader, model, criterion, start_epoch,
                                   use_cuda)
        print(' Test Loss:  %.8f, Test Acc:  %.2f' % (test_loss, test_acc))
        return

    # Train and val
    for epoch in range(start_epoch, args.epochs):
        train_sampler.set_epoch(epoch)

        adjust_learning_rate(optimizer, epoch)

        if args.local_rank == 0:
            print('\nEpoch: [%d | %d] LR: %f' %
                  (epoch + 1, args.epochs, state['lr']))

        train_loss, train_acc = train(train_loader, model, criterion,
                                      optimizer, epoch, use_cuda)
        test_loss, test_acc = test(val_loader, model, criterion, epoch,
                                   use_cuda)

        # save model
        if args.local_rank == 0:
            # append logger file
            logger.append(
                [state['lr'], train_loss, test_loss, train_acc, test_acc])

            is_best = test_acc > best_acc
            best_acc = max(test_acc, best_acc)
            save_checkpoint(
                {
                    'epoch': epoch + 1,
                    'state_dict': model.state_dict(),
                    'acc': test_acc,
                    'best_acc': best_acc,
                    'optimizer': optimizer.state_dict(),
                },
                is_best,
                checkpoint=args.checkpoint)

    if args.local_rank == 0:
        logger.close()

    print('Best acc:')
    print(best_acc)
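The transforms above deliberately stop before ToTensor/normalize because fast_collate hands uint8 image batches to the training loop, which then normalizes them on the GPU. A hedged sketch of such a collate function, modeled on NVIDIA's apex ImageNet example (the fast_collate actually imported here may differ):

import numpy as np
import torch

def fast_collate(batch):
    # batch is a list of (PIL image, target) pairs produced by ImageFolder
    imgs = [sample[0] for sample in batch]
    targets = torch.tensor([sample[1] for sample in batch], dtype=torch.int64)
    w, h = imgs[0].size[0], imgs[0].size[1]
    tensor = torch.zeros((len(imgs), 3, h, w), dtype=torch.uint8)
    for i, img in enumerate(imgs):
        array = np.asarray(img, dtype=np.uint8)
        if array.ndim < 3:
            array = np.expand_dims(array, axis=-1)
        array = np.rollaxis(array, 2)  # HWC -> CHW
        tensor[i] += torch.from_numpy(array)
    return tensor, targets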
Example #3
            optimizer = FP16_Optimizer(optimizer, dynamic_loss_scale=True)
        else:
            optimizer = FP16_Optimizer(optimizer,
                                       static_loss_scale=args.loss_scale)

    else:
        optimizer = BertAdam(optimizer_grouped_parameters,
                             lr=args.lr,
                             warmup=args.warmup_proportion,
                             t_total=num_train_optimization_steps)

    trainer = BertTrainer(model, optimizer, processor, args)

    if not args.trained_model:
        trainer.train()
        model = torch.load(trainer.snapshot_path)
    else:
        model = BertForSequenceClassification.from_pretrained(
            args.model, num_labels=args.num_labels)
        model_ = torch.load(args.trained_model,
                            map_location=lambda storage, loc: storage)
        state = {}
        for key in model_.state_dict().keys():
            new_key = key.replace("module.", "")
            state[new_key] = model_.state_dict()[key]
        model.load_state_dict(state)
        model = model.to(device)

    evaluate_split(model, processor, args, split='dev')
    evaluate_split(model, processor, args, split='test')
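This fragment relies on the older apex FP16_Optimizer wrapper, whose key difference from a plain optimizer is that the backward pass goes through the wrapper so it can apply loss scaling. A minimal sketch of one training step under that assumption (a model returning its loss directly is a placeholder convention):

def training_step(model, optimizer, batch, fp16=False):
    loss = model(**batch)         # placeholder: model returns the loss when labels are passed
    if fp16:
        optimizer.backward(loss)  # FP16_Optimizer scales the loss internally
    else:
        loss.backward()
    optimizer.step()
    optimizer.zero_grad()
    return float(loss)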
Example #4
def do_main():
    # Set default configuration in args.py
    args = get_args()

    if args.local_rank == -1 or not args.cuda:
        device = torch.device(
            "cuda" if torch.cuda.is_available() and args.cuda else "cpu")
        n_gpu = torch.cuda.device_count()
        torch.cuda.set_device(args.gpu)
    else:
        torch.cuda.set_device(args.local_rank)
        device = torch.device("cuda", args.local_rank)
        n_gpu = 1
        # Initializes the distributed backend which will take care of synchronizing nodes/GPUs
        torch.distributed.init_process_group(backend='nccl')

    print('Device:', str(device).upper())
    print('Number of GPUs:', n_gpu)
    print('Distributed training:', bool(args.local_rank != -1))
    print('FP16:', args.fp16)

    # Set random seed for reproducibility
    random.seed(args.seed)
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)
    if n_gpu > 0:
        torch.cuda.manual_seed_all(args.seed)

    dataset_map = {
        'SST-2': SST2Processor,
        'Reuters': ReutersProcessor,
        'IMDB': IMDBProcessor,
        'AAPD': AAPDProcessor,
        'AGNews': AGNewsProcessor,
        'Yelp2014': Yelp2014Processor,
        'Sogou': SogouProcessor,
        'Personality': PersonalityProcessor,
        'News_art': News_artProcessor,
        'News': News_Processor,
        'UCI_yelp': UCI_yelpProcessor,
        'Procon': ProconProcessor,
        'Style': StyleProcessor,
        'ProconDual': ProconDualProcessor,
        'Pan15': Pan15_Processor,
        'Pan14E': Pan14E_Processor,
        'Pan14N': Pan14N_Processor,
        'Perspectrum': PerspectrumProcessor
    }

    if args.gradient_accumulation_steps < 1:
        raise ValueError(
            "Invalid gradient_accumulation_steps parameter: {}, should be >= 1"
            .format(args.gradient_accumulation_steps))

    if args.dataset not in dataset_map:
        raise ValueError('Unrecognized dataset')

    args.batch_size = args.batch_size // args.gradient_accumulation_steps
    args.device = device
    args.n_gpu = n_gpu
    args.num_labels = dataset_map[args.dataset].NUM_CLASSES
    args.is_multilabel = dataset_map[args.dataset].IS_MULTILABEL

    if not args.trained_model:
        save_path = os.path.join(args.save_path,
                                 dataset_map[args.dataset].NAME)
        os.makedirs(save_path, exist_ok=True)

    processor = dataset_map[args.dataset]()
    args.is_lowercase = 'uncased' in args.model
    args.is_hierarchical = False
    tokenizer = BertTokenizer.from_pretrained(args.model,
                                              is_lowercase=args.is_lowercase)

    train_examples = None
    num_train_optimization_steps = None
    if not args.trained_model:
        train_examples = processor.get_train_examples(args.data_dir,
                                                      args.train_name)
        num_train_optimization_steps = int(
            len(train_examples) / args.batch_size /
            args.gradient_accumulation_steps) * args.epochs
        if args.local_rank != -1:
            num_train_optimization_steps = (num_train_optimization_steps //
                                            torch.distributed.get_world_size())

    cache_dir = args.cache_dir if args.cache_dir else os.path.join(
        str(PYTORCH_PRETRAINED_BERT_CACHE), 'distributed_{}'.format(
            args.local_rank))
    model = BertForSequenceClassification.from_pretrained(
        args.model, cache_dir=cache_dir, num_labels=args.num_labels)

    if args.fp16:
        model.half()
    model.to(device)

    if args.local_rank != -1:
        try:
            from apex.parallel import DistributedDataParallel as DDP
        except ImportError:
            raise ImportError(
                "Install NVIDIA Apex to use distributed and FP16 training.")
        model = DDP(model)
    # elif n_gpu > 1:  # changed by marjan
    #     model = torch.nn.DataParallel(model)

    # Prepare optimizer
    param_optimizer = list(model.named_parameters())
    no_decay = ['bias', 'LayerNorm.bias', 'LayerNorm.weight']
    optimizer_grouped_parameters = [{
        'params':
        [p for n, p in param_optimizer if not any(nd in n for nd in no_decay)],
        'weight_decay':
        0.01
    }, {
        'params':
        [p for n, p in param_optimizer if any(nd in n for nd in no_decay)],
        'weight_decay':
        0.0
    }]

    if args.fp16:
        try:
            from apex.optimizers import FP16_Optimizer
            from apex.optimizers import FusedAdam
        except ImportError:
            raise ImportError(
                "Please install NVIDIA Apex for distributed and FP16 training")

        optimizer = FusedAdam(optimizer_grouped_parameters,
                              lr=args.lr,
                              bias_correction=False,
                              max_grad_norm=1.0)
        if args.loss_scale == 0:
            optimizer = FP16_Optimizer(optimizer, dynamic_loss_scale=True)
        else:
            optimizer = FP16_Optimizer(optimizer,
                                       static_loss_scale=args.loss_scale)

    else:
        optimizer = BertAdam(optimizer_grouped_parameters,
                             lr=args.lr,
                             warmup=args.warmup_proportion,
                             t_total=num_train_optimization_steps)

    trainer = BertTrainer(model, optimizer, processor, args)

    if not args.trained_model:
        trainer.train()
        model = torch.load(trainer.snapshot_path)
    else:
        model = BertForSequenceClassification.from_pretrained(
            args.model, num_labels=args.num_labels)
        model_ = torch.load(args.trained_model,
                            map_location=lambda storage, loc: storage)
        state = {}
        for key in model_.state_dict().keys():
            new_key = key.replace("module.", "")
            state[new_key] = model_.state_dict()[key]
        model.load_state_dict(state)
        model = model.to(device)

    evaluate_split(model, processor, args, split=args.dev_name)
    evaluate_split(model, processor, args, split=args.test_name)
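Note how the arithmetic around gradient accumulation works above: args.batch_size is first divided by gradient_accumulation_steps to get the per-forward micro-batch, and the number of optimizer updates then divides by both. A small worked example with illustrative numbers:

# 10,000 examples, batch_size=32, gradient_accumulation_steps=4, epochs=3
train_examples_count = 10_000
batch_size, accum_steps, epochs = 32, 4, 3
micro_batch = batch_size // accum_steps                                    # 8 examples per forward pass
updates = int(train_examples_count / micro_batch / accum_steps) * epochs   # int(312.5) * 3 = 936 optimizer steps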
Example #5
class Seq2SeqTrainer:
    """
    Seq2SeqTrainer
    """
    def __init__(self,
                 model,
                 criterion,
                 opt_config,
                 scheduler_config,
                 print_freq=10,
                 save_freq=1000,
                 grad_clip=float('inf'),
                 batch_first=False,
                 save_info={},
                 save_path='.',
                 train_iterations=0,
                 checkpoint_filename='checkpoint%s.pth',
                 keep_checkpoints=5,
                 math='fp32',
                 cuda=True,
                 distributed=False,
                 intra_epoch_eval=0,
                 iter_size=1,
                 translator=None,
                 verbose=False):
        """
        Constructor for the Seq2SeqTrainer.

        :param model: model to train
        :param criterion: criterion (loss function)
        :param opt_config: dictionary with options for the optimizer
        :param scheduler_config: dictionary with options for the learning rate
            scheduler
        :param print_freq: prints short summary every 'print_freq' iterations
        :param save_freq: saves checkpoint every 'save_freq' iterations
        :param grad_clip: coefficient for gradient clipping
        :param batch_first: if True the model uses (batch, seq, feature) tensors,
            if False the model uses (seq, batch, feature)
        :param save_info: dict with additional state stored in each checkpoint
        :param save_path: path to the directory for checkpoints
        :param train_iterations: total number of training iterations to execute
        :param checkpoint_filename: name of files with checkpoints
        :param keep_checkpoints: max number of checkpoints to keep
        :param math: arithmetic type
        :param cuda: if True use cuda, if False train on cpu
        :param distributed: if True run distributed training
        :param intra_epoch_eval: number of additional eval runs within each
            training epoch
        :param iter_size: number of iterations between weight updates
        :param translator: instance of Translator, runs inference on test set
        :param verbose: enables verbose logging
        """
        super(Seq2SeqTrainer, self).__init__()
        self.model = model
        self.criterion = criterion
        self.epoch = 0
        self.save_info = save_info
        self.save_path = save_path
        self.save_freq = save_freq
        self.save_counter = 0
        self.checkpoint_filename = checkpoint_filename
        self.checkpoint_counter = cycle(range(keep_checkpoints))
        self.opt_config = opt_config
        self.cuda = cuda
        self.distributed = distributed
        self.print_freq = print_freq
        self.batch_first = batch_first
        self.verbose = verbose
        self.loss = None
        self.translator = translator
        self.intra_epoch_eval = intra_epoch_eval
        self.iter_size = iter_size

        if cuda:
            self.model = self.model.cuda()
            self.criterion = self.criterion.cuda()

        if math == 'fp16':
            self.model = self.model.half()

        if distributed:
            self.model = DDP(self.model)

        if math == 'fp16':
            self.fp_optimizer = Fp16Optimizer(self.model, grad_clip)
            params = self.fp_optimizer.fp32_params
        elif math == 'fp32':
            self.fp_optimizer = Fp32Optimizer(self.model, grad_clip)
            params = self.model.parameters()

        opt_name = opt_config.pop('optimizer')
        self.optimizer = torch.optim.__dict__[opt_name](params, **opt_config)
        logging.info(f'Using optimizer: {self.optimizer}')

        self.scheduler = WarmupMultiStepLR(self.optimizer, train_iterations,
                                           **scheduler_config)

    def iterate(self, src, tgt, update=True, training=True):
        """
        Performs one iteration of the training/validation.

        :param src: batch of examples from the source language
        :param tgt: batch of examples from the target language
        :param update: if True: optimizer does update of the weights
        :param training: if True: executes optimizer
        """
        src, src_length = src
        tgt, tgt_length = tgt
        src_length = torch.LongTensor(src_length)
        tgt_length = torch.LongTensor(tgt_length)

        num_toks = {}
        num_toks['tgt'] = int(sum(tgt_length - 1))
        num_toks['src'] = int(sum(src_length))

        if self.cuda:
            src = src.cuda()
            src_length = src_length.cuda()
            tgt = tgt.cuda()

        if self.batch_first:
            output = self.model(src, src_length, tgt[:, :-1])
            tgt_labels = tgt[:, 1:]
            T, B = output.size(1), output.size(0)
        else:
            output = self.model(src, src_length, tgt[:-1])
            tgt_labels = tgt[1:]
            T, B = output.size(0), output.size(1)

        loss = self.criterion(output.view(T * B, -1),
                              tgt_labels.contiguous().view(-1))

        loss_per_batch = loss.item()
        loss /= (B * self.iter_size)

        if training:
            self.fp_optimizer.step(loss, self.optimizer, self.scheduler,
                                   update)

        loss_per_token = loss_per_batch / num_toks['tgt']
        loss_per_sentence = loss_per_batch / B

        return loss_per_token, loss_per_sentence, num_toks

    def feed_data(self, data_loader, training=True):
        """
        Runs training or validation on batches from data_loader.

        :param data_loader: data loader
        :param training: if True runs training else runs validation
        """
        if training:
            assert self.optimizer is not None
            eval_fractions = np.linspace(0, 1, self.intra_epoch_eval+2)[1:-1]
            iters_with_update = len(data_loader) // self.iter_size
            eval_iters = (eval_fractions * iters_with_update).astype(int)
            eval_iters = eval_iters * self.iter_size
            eval_iters = set(eval_iters)

        batch_time = AverageMeter()
        data_time = AverageMeter()
        losses_per_token = AverageMeter(skip_first=False)
        losses_per_sentence = AverageMeter(skip_first=False)

        tot_tok_time = AverageMeter()
        src_tok_time = AverageMeter()
        tgt_tok_time = AverageMeter()

        batch_size = data_loader.batch_size

        end = time.time()
        for i, (src, tgt) in enumerate(data_loader):
            self.save_counter += 1
            # measure data loading time
            data_time.update(time.time() - end)

            update = False
            if i % self.iter_size == self.iter_size - 1:
                update = True

            # do a train/evaluate iteration
            stats = self.iterate(src, tgt, update, training=training)
            loss_per_token, loss_per_sentence, num_toks = stats

            # measure accuracy and record loss
            losses_per_token.update(loss_per_token, num_toks['tgt'])
            losses_per_sentence.update(loss_per_sentence, batch_size)

            # measure elapsed time
            elapsed = time.time() - end
            batch_time.update(elapsed)
            src_tok_time.update(num_toks['src'] / elapsed)
            tgt_tok_time.update(num_toks['tgt'] / elapsed)
            tot_num_toks = num_toks['tgt'] + num_toks['src']
            tot_tok_time.update(tot_num_toks / elapsed)
            self.loss = losses_per_token.avg

            if training and i in eval_iters:
                test_bleu, _ = self.translator.run(calc_bleu=True,
                                                   epoch=self.epoch,
                                                   iteration=i)

                log = []
                log += [f'TRAIN [{self.epoch}][{i}/{len(data_loader)}]']
                log += [f'BLEU: {test_bleu:.2f}']
                log = '\t'.join(log)
                logging.info(log)

                self.model.train()
                self.preallocate(data_loader, training=True)

            if i % self.print_freq == 0:
                phase = 'TRAIN' if training else 'VALIDATION'
                log = []
                log += [f'{phase} [{self.epoch}][{i}/{len(data_loader)}]']
                log += [f'Time {batch_time.val:.3f} ({batch_time.avg:.3f})']
                log += [f'Data {data_time.val:.2e} ({data_time.avg:.2e})']
                log += [f'Tok/s {tot_tok_time.val:.0f} ({tot_tok_time.avg:.0f})']
                if self.verbose:
                    log += [f'Src tok/s {src_tok_time.val:.0f} ({src_tok_time.avg:.0f})']
                    log += [f'Tgt tok/s {tgt_tok_time.val:.0f} ({tgt_tok_time.avg:.0f})']
                    log += [f'Loss/sentence {losses_per_sentence.val:.1f} ({losses_per_sentence.avg:.1f})']
                log += [f'Loss/tok {losses_per_token.val:.4f} ({losses_per_token.avg:.4f})']
                if training:
                    lr = self.optimizer.param_groups[0]['lr']
                    log += [f'LR {lr:.3e}']
                log = '\t'.join(log)
                logging.info(log)

            save_chkpt = (self.save_counter % self.save_freq) == (self.save_freq - 1)
            if training and save_chkpt:
                self.save_counter = 0
                self.save_info['iteration'] = i
                identifier = next(self.checkpoint_counter, -1)
                if identifier != -1:
                    with sync_workers() as rank:
                        if rank == 0:
                            self.save(identifier=identifier)

            end = time.time()

        tot_tok_time.reduce('sum')
        losses_per_token.reduce('mean')

        return losses_per_token.avg, tot_tok_time.avg

    def preallocate(self, data_loader, training):
        """
        Generates maximum sequence length batch and runs forward and backward
        pass without updating model parameters.

        :param data_loader: data loader
        :param training: if True preallocates memory for backward pass
        """
        batch_size = data_loader.batch_size
        max_len = data_loader.dataset.max_len

        src_length = [max_len] * batch_size
        tgt_length = [max_len] * batch_size

        if self.batch_first:
            shape = (batch_size, max_len)
        else:
            shape = (max_len, batch_size)

        src = torch.full(shape, 4, dtype=torch.int64)
        tgt = torch.full(shape, 4, dtype=torch.int64)
        src = src, src_length
        tgt = tgt, tgt_length
        self.iterate(src, tgt, update=False, training=training)
        self.model.zero_grad()

    def optimize(self, data_loader):
        """
        Sets model in training mode, preallocates memory and runs training on
        data provided by data_loader.

        :param data_loader: data loader
        """
        torch.set_grad_enabled(True)
        self.model.train()
        torch.cuda.empty_cache()
        self.preallocate(data_loader, training=True)
        output = self.feed_data(data_loader, training=True)
        self.model.zero_grad()
        torch.cuda.empty_cache()
        return output

    def evaluate(self, data_loader):
        """
        Sets model in eval mode, disables gradients, preallocates memory and
        runs validation on data provided by data_loader.

        :param data_loader: data loader
        """
        torch.set_grad_enabled(False)
        self.model.eval()
        torch.cuda.empty_cache()
        self.preallocate(data_loader, training=False)
        output = self.feed_data(data_loader, training=False)
        self.model.zero_grad()
        torch.cuda.empty_cache()
        return output

    def load(self, filename):
        """
        Loads checkpoint from filename.

        :param filename: path to the checkpoint file
        """
        if os.path.isfile(filename):
            checkpoint = torch.load(filename, map_location={'cuda:0': 'cpu'})
            if self.distributed:
                self.model.module.load_state_dict(checkpoint['state_dict'])
            else:
                self.model.load_state_dict(checkpoint['state_dict'])
            self.fp_optimizer.initialize_model(self.model)
            self.optimizer.load_state_dict(checkpoint['optimizer'])
            self.scheduler.load_state_dict(checkpoint['scheduler'])
            self.epoch = checkpoint['epoch']
            self.loss = checkpoint['loss']
            logging.info(f'Loaded checkpoint {filename} (epoch {self.epoch})')
        else:
            logging.error(f'Invalid checkpoint: {filename}')

    def save(self, identifier=None, is_best=False, save_all=False):
        """
        Stores checkpoint to a file.

        :param identifier: identifier for periodic checkpoint
        :param is_best: if True stores checkpoint to 'model_best.pth'
        :param save_all: if True stores checkpoint after completed training
            epoch
        """

        def write_checkpoint(state, filename):
            filename = os.path.join(self.save_path, filename)
            logging.info(f'Saving model to {filename}')
            torch.save(state, filename)

        if self.distributed:
            model_state = self.model.module.state_dict()
        else:
            model_state = self.model.state_dict()

        state = {
            'epoch': self.epoch,
            'state_dict': model_state,
            'optimizer': self.optimizer.state_dict(),
            'scheduler': self.scheduler.state_dict(),
            'loss': getattr(self, 'loss', None),
        }
        state = dict(list(state.items()) + list(self.save_info.items()))

        if identifier is not None:
            filename = self.checkpoint_filename % identifier
            write_checkpoint(state, filename)

        if is_best:
            filename = 'model_best.pth'
            write_checkpoint(state, filename)

        if save_all:
            filename = f'checkpoint_epoch_{self.epoch:03d}.pth'
            write_checkpoint(state, filename)
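A minimal usage sketch for the trainer above; the model, criterion, data loaders, num_epochs and scheduler_config are assumed to be built elsewhere, and the numbers are placeholders:

trainer = Seq2SeqTrainer(model=model,
                         criterion=criterion,
                         opt_config={'optimizer': 'Adam', 'lr': 1e-3},
                         scheduler_config=scheduler_config,  # forwarded to WarmupMultiStepLR
                         train_iterations=len(train_loader) * num_epochs,
                         math='fp16',
                         cuda=True,
                         distributed=False)

best_loss = float('inf')
for epoch in range(num_epochs):
    trainer.epoch = epoch
    train_loss, train_tok_per_sec = trainer.optimize(train_loader)
    val_loss, _ = trainer.evaluate(val_loader)
    trainer.save(save_all=True, is_best=val_loss < best_loss)
    best_loss = min(best_loss, val_loss)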
Example #6
def main():

    parser = argparse.ArgumentParser(description='PyTorch Tacotron 2 Training')
    parser = parse_args(parser)
    args, _ = parser.parse_known_args()

    LOGGER.set_model_name("Tacotron2_PyT")
    LOGGER.set_backends([
        dllg.StdOutBackend(log_file=None,
                           logging_scope=dllg.TRAIN_ITER_SCOPE,
                           iteration_interval=1),
        dllg.JsonBackend(log_file=args.log_file if args.rank == 0 else None,
                         logging_scope=dllg.TRAIN_ITER_SCOPE,
                         iteration_interval=1)
    ])

    LOGGER.timed_block_start("run")
    LOGGER.register_metric(tags.TRAIN_ITERATION_LOSS,
                           metric_scope=dllg.TRAIN_ITER_SCOPE)
    LOGGER.register_metric("iter_time", metric_scope=dllg.TRAIN_ITER_SCOPE)
    LOGGER.register_metric("epoch_time", metric_scope=dllg.EPOCH_SCOPE)
    LOGGER.register_metric("run_time", metric_scope=dllg.RUN_SCOPE)
    LOGGER.register_metric("val_iter_loss", metric_scope=dllg.EPOCH_SCOPE)
    LOGGER.register_metric("train_epoch_items/sec",
                           metric_scope=dllg.EPOCH_SCOPE)
    LOGGER.register_metric("train_epoch_avg_items/sec",
                           metric_scope=dllg.EPOCH_SCOPE)
    LOGGER.register_metric("train_epoch_avg_loss",
                           metric_scope=dllg.EPOCH_SCOPE)

    log_hardware()

    model_name = args.model_name
    parser = models.parse_model_args(model_name, parser)
    args = parser.parse_args()

    log_args(args)

    torch.backends.cudnn.enabled = args.cudnn_enabled
    torch.backends.cudnn.benchmark = args.cudnn_benchmark

    distributed_run = args.world_size > 1
    if distributed_run:
        init_distributed(args, args.world_size, args.rank, args.group_name)

    LOGGER.log(key=tags.RUN_START)
    run_start_time = time.time()

    model_config = models.get_model_config(model_name, args)
    model = models.get_model(model_name,
                             model_config,
                             to_cuda=True,
                             uniform_initialize_bn_weight=not args.disable_uniform_initialize_bn_weight)

    optimizer = torch.optim.Adam(model.parameters(),
                                 lr=args.learning_rate,
                                 weight_decay=args.weight_decay)

    if args.amp_run:
        model, optimizer = amp.initialize(model, optimizer, opt_level="O1")
        if distributed_run:
            model = DDP(model)

    if args.checkpoint != "":
        checkpoint = torch.load(args.checkpoint)
        state_dict = checkpoint['state_dict']

        if checkpoint_from_distributed(state_dict):
            state_dict = unwrap_distributed(state_dict)

        if args.amp_run:
            amp.load_state_dict(checkpoint['amp'])

        optimizer.load_state_dict(checkpoint['optimizer'])
        model.load_state_dict(state_dict)
        print("Loaded from checkpoint: %s !" % args.checkpoint)

    if not args.amp_run and distributed_run:
        model = DDP(model)

    try:
        sigma = args.sigma
    except AttributeError:
        sigma = None

    criterion = loss_functions.get_loss_function(model_name, sigma)

    try:
        n_frames_per_step = args.n_frames_per_step
    except AttributeError:
        n_frames_per_step = None

    collate_fn = data_functions.get_collate_function(model_name,
                                                     n_frames_per_step)
    trainset = data_functions.get_data_loader(model_name, args.dataset_path,
                                              args.training_files, args)
    train_sampler = DistributedSampler(trainset) if distributed_run else None
    train_loader = DataLoader(trainset,
                              num_workers=1,
                              shuffle=False,
                              sampler=train_sampler,
                              batch_size=args.batch_size,
                              pin_memory=False,
                              drop_last=True,
                              collate_fn=collate_fn)

    valset = data_functions.get_data_loader(model_name, args.dataset_path,
                                            args.validation_files, args)

    batch_to_gpu = data_functions.get_batch_to_gpu(model_name)

    iteration = 0
    model.train()

    LOGGER.log(key=tags.TRAIN_LOOP)

    for epoch in range(args.epochs):
        LOGGER.epoch_start()
        epoch_start_time = time.time()
        LOGGER.log(key=tags.TRAIN_EPOCH_START, value=epoch)

        # used to calculate avg items/sec over epoch
        reduced_num_items_epoch = 0

        # used to calculate avg loss over epoch
        train_epoch_avg_loss = 0.0
        train_epoch_avg_items_per_sec = 0.0
        num_iters = 0

        # if overflow at the last iteration then do not save checkpoint
        overflow = False

        for i, batch in enumerate(train_loader):
            print("Batch: {}/{} epoch {}".format(i, len(train_loader), epoch))
            LOGGER.iteration_start()
            iter_start_time = time.time()
            LOGGER.log(key=tags.TRAIN_ITER_START, value=i)

            start = time.perf_counter()
            adjust_learning_rate(epoch, optimizer, args.learning_rate,
                                 args.anneal_steps, args.anneal_factor)

            model.zero_grad()
            x, y, num_items = batch_to_gpu(batch)

            y_pred = model(x)
            loss = criterion(y_pred, y)

            if distributed_run:
                reduced_loss = reduce_tensor(loss.data, args.world_size).item()
                reduced_num_items = reduce_tensor(num_items.data, 1).item()
            else:
                reduced_loss = loss.item()
                reduced_num_items = num_items.item()
            if np.isnan(reduced_loss):
                raise Exception("loss is NaN")

            LOGGER.log(key=tags.TRAIN_ITERATION_LOSS, value=reduced_loss)

            train_epoch_avg_loss += reduced_loss
            num_iters += 1

            # accumulate number of items processed in this epoch
            reduced_num_items_epoch += reduced_num_items

            if args.amp_run:
                with amp.scale_loss(loss, optimizer) as scaled_loss:
                    scaled_loss.backward()
                grad_norm = torch.nn.utils.clip_grad_norm_(
                    amp.master_params(optimizer), args.grad_clip_thresh)
            else:
                loss.backward()
                grad_norm = torch.nn.utils.clip_grad_norm_(
                    model.parameters(), args.grad_clip_thresh)

            optimizer.step()

            iteration += 1

            LOGGER.log(key=tags.TRAIN_ITER_STOP, value=i)

            iter_stop_time = time.time()
            iter_time = iter_stop_time - iter_start_time
            items_per_sec = reduced_num_items / iter_time
            train_epoch_avg_items_per_sec += items_per_sec

            LOGGER.log(key="train_iter_items/sec", value=items_per_sec)
            LOGGER.log(key="iter_time", value=iter_time)
            LOGGER.iteration_stop()

        LOGGER.log(key=tags.TRAIN_EPOCH_STOP, value=epoch)
        epoch_stop_time = time.time()
        epoch_time = epoch_stop_time - epoch_start_time

        LOGGER.log(key="train_epoch_items/sec",
                   value=(reduced_num_items_epoch / epoch_time))
        LOGGER.log(key="train_epoch_avg_items/sec",
                   value=(train_epoch_avg_items_per_sec /
                          num_iters if num_iters > 0 else 0.0))
        LOGGER.log(key="train_epoch_avg_loss",
                   value=(train_epoch_avg_loss /
                          num_iters if num_iters > 0 else 0.0))
        LOGGER.log(key="epoch_time", value=epoch_time)

        LOGGER.log(key=tags.EVAL_START, value=epoch)

        validate(model, criterion, valset, iteration, args.batch_size,
                 args.world_size, collate_fn, distributed_run, args.rank,
                 batch_to_gpu)

        LOGGER.log(key=tags.EVAL_STOP, value=epoch)

        if (epoch % args.epochs_per_checkpoint == 0) and args.rank == 0:
            checkpoint_path = os.path.join(
                args.output_directory,
                "checkpoint_{}_{}".format(model_name, epoch))
            save_checkpoint(model, optimizer, epoch, model_config,
                            checkpoint_path, amp if args.amp_run else None)
            save_sample(
                model_name, model, args.waveglow_checkpoint,
                args.tacotron2_checkpoint, args.phrase_path,
                os.path.join(args.output_directory,
                             "sample_{}_{}.wav".format(model_name, iteration)),
                args.sampling_rate)

        LOGGER.epoch_stop()

    run_stop_time = time.time()
    run_time = run_stop_time - run_start_time
    LOGGER.log(key="run_time", value=run_time)
    LOGGER.log(key=tags.RUN_FINAL)

    print("training time", run_stop_time - run_start_time)

    LOGGER.timed_block_stop("run")

    if args.rank == 0:
        LOGGER.finish()
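On resume, the script above expects the checkpoint to contain an 'amp' entry next to the model and optimizer state whenever mixed precision was used. A hedged sketch of the save side that would produce such a checkpoint (the real save_checkpoint may store additional fields):

import torch

def save_checkpoint_sketch(model, optimizer, epoch, model_config, filepath, amp_module=None):
    checkpoint = {'epoch': epoch,
                  'config': model_config,
                  'state_dict': model.state_dict(),
                  'optimizer': optimizer.state_dict()}
    if amp_module is not None:
        checkpoint['amp'] = amp_module.state_dict()  # apex amp exposes state_dict()/load_state_dict()
    torch.save(checkpoint, filepath)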
Example #7
def main():
    global best_top1, best_top5

    start_epoch = args.start_epoch  # start from epoch 0 or last checkpoint epoch

    if not os.path.isdir(args.checkpoint):
        mkdir_p(args.checkpoint)

    # Data loading code
    traindir = os.path.join(args.data, 'train')
    validdir = os.path.join(args.data, 'val')
    normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                     std=[0.229, 0.224, 0.225])

    train_transform = transforms.Compose([
        transforms.RandomResizedCrop(224),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        normalize,
    ])

    val_transform = transforms.Compose([
        transforms.Resize(256),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        normalize,
    ])

    train_data = imagenet_lmdb_dataset(traindir, transform=train_transform)
    valid_data = imagenet_lmdb_dataset(validdir, transform=val_transform)

    train_sampler = torch.utils.data.distributed.DistributedSampler(train_data)

    train_loader = torch.utils.data.DataLoader(train_data,
                                               batch_size=args.train_batch,
                                               shuffle=(train_sampler is None),
                                               pin_memory=True,
                                               num_workers=8,
                                               sampler=train_sampler)

    val_loader = torch.utils.data.DataLoader(valid_data,
                                             batch_size=args.test_batch,
                                             shuffle=False,
                                             pin_memory=True,
                                             num_workers=8)

    # create model
    if args.pretrained:
        print("=> using pre-trained model '{}'".format(args.arch))
        model = models.__dict__[args.arch](pretrained=True)
    elif args.arch.startswith('resnext'):
        model = models.__dict__[args.arch](
            baseWidth=args.base_width,
            cardinality=args.cardinality,
        )
    else:
        print("=> creating model '{}'".format(args.arch))
        model = models.__dict__[args.arch]()

    if args.arch.startswith('alexnet') or args.arch.startswith('vgg'):
        model.features = DDP(model.features)
        model.cuda()
    else:
        model = model.cuda()
        model = DDP(model, delay_allreduce=True)

    # define loss function (criterion) and optimizer
    criterion = nn.CrossEntropyLoss().cuda()
    if args.optimizer.lower() == 'sgd':
        optimizer = optim.SGD(model.parameters(),
                              lr=args.lr,
                              momentum=args.momentum,
                              weight_decay=args.weight_decay)
    elif args.optimizer.lower() == 'adamw':
        optimizer = AdamW(model.parameters(),
                          lr=args.lr,
                          betas=(args.beta1, args.beta2),
                          weight_decay=args.weight_decay,
                          warmup=0)
    elif args.optimizer.lower() == 'radam':
        optimizer = RAdam(model.parameters(),
                          lr=args.lr,
                          betas=(args.beta1, args.beta2),
                          weight_decay=args.weight_decay)
    elif args.optimizer.lower() == 'lsadam':
        optimizer = LSAdamW(model.parameters(),
                            lr=args.lr * ((1. + 4. * args.sigma)**(0.25)),
                            betas=(args.beta1, args.beta2),
                            weight_decay=args.weight_decay,
                            sigma=args.sigma)
    elif args.optimizer.lower() == 'lsradam':
        sigma = 0.1
        optimizer = LSRAdam(model.parameters(),
                            lr=args.lr * ((1. + 4. * args.sigma)**(0.25)),
                            betas=(args.beta1, args.beta2),
                            weight_decay=args.weight_decay,
                            sigma=args.sigma)
    elif args.optimizer.lower() == 'srsgd':
        iter_count = 1
        optimizer = SGD_Adaptive(model.parameters(),
                                 lr=args.lr,
                                 weight_decay=args.weight_decay,
                                 iter_count=iter_count,
                                 restarting_iter=args.restart_schedule[0])
    elif args.optimizer.lower() == 'sradam':
        iter_count = 1
        optimizer = SRNAdam(model.parameters(),
                            lr=args.lr,
                            betas=(args.beta1, args.beta2),
                            iter_count=iter_count,
                            weight_decay=args.weight_decay,
                            restarting_iter=args.restart_schedule[0])
    elif args.optimizer.lower() == 'sradamw':
        iter_count = 1
        optimizer = SRAdamW(model.parameters(),
                            lr=args.lr,
                            betas=(args.beta1, args.beta2),
                            iter_count=iter_count,
                            weight_decay=args.weight_decay,
                            warmup=0,
                            restarting_iter=args.restart_schedule[0])
    elif args.optimizer.lower() == 'srradam':
        #NOTE: need to double-check this
        iter_count = 1
        optimizer = SRRAdam(model.parameters(),
                            lr=args.lr,
                            betas=(args.beta1, args.beta2),
                            iter_count=iter_count,
                            weight_decay=args.weight_decay,
                            warmup=0,
                            restarting_iter=args.restart_schedule[0])

    schedule_index = 1
    # Resume
    title = 'ImageNet-' + args.arch
    if args.resume:
        # Load checkpoint.
        print('==> Resuming from checkpoint..')
        assert os.path.isfile(
            args.resume), 'Error: no checkpoint directory found!'
        # args.checkpoint = os.path.dirname(args.resume)
        # checkpoint = torch.load(args.resume, map_location = lambda storage, loc: storage.cuda(args.local_rank))
        checkpoint = torch.load(args.resume, map_location=torch.device('cpu'))
        best_top1 = checkpoint['best_top1']
        best_top5 = checkpoint['best_top5']
        start_epoch = checkpoint['epoch']
        model.load_state_dict(checkpoint['state_dict'])
        optimizer.load_state_dict(checkpoint['optimizer'])
        if args.optimizer.lower() in ('srsgd', 'sradam', 'sradamw', 'srradam'):
            iter_count = optimizer.param_groups[0]['iter_count']
        schedule_index = checkpoint['schedule_index']
        state['lr'] = optimizer.param_groups[0]['lr']
        if args.checkpoint == args.resume:
            logger = LoggerDistributed(os.path.join(args.checkpoint,
                                                    'log.txt'),
                                       rank=args.local_rank,
                                       title=title,
                                       resume=True)
        else:
            logger = LoggerDistributed(os.path.join(args.checkpoint,
                                                    'log.txt'),
                                       rank=args.local_rank,
                                       title=title)
            if args.local_rank == 0:
                logger.set_names([
                    'Learning Rate', 'Train Loss', 'Valid Loss', 'Train Top1',
                    'Valid Top1', 'Train Top5', 'Valid Top5'
                ])
    else:
        logger = LoggerDistributed(os.path.join(args.checkpoint, 'log.txt'),
                                   rank=args.local_rank,
                                   title=title)
        if args.local_rank == 0:
            logger.set_names([
                'Learning Rate', 'Train Loss', 'Valid Loss', 'Train Top1',
                'Valid Top1', 'Train Top5', 'Valid Top5'
            ])

    if args.local_rank == 0:
        logger.file.write('    Total params: %.2fM' %
                          (sum(p.numel()
                               for p in model.parameters()) / 1000000.0))

    if args.evaluate:
        if args.local_rank == 0:
            logger.file.write('\nEvaluation only')
        test_loss, test_top1, test_top5 = test(val_loader, model, criterion,
                                               start_epoch, use_cuda, logger)
        if args.local_rank == 0:
            logger.file.write(
                ' Test Loss:  %.8f, Test Top1:  %.2f, Test Top5: %.2f' %
                (test_loss, test_top1, test_top5))
        return

    # Train and val
    for epoch in range(start_epoch, args.epochs):
        # Shuffle the sampler.
        train_loader.sampler.set_epoch(epoch + args.manualSeed)

        if args.optimizer.lower() == 'srsgd':
            if epoch in args.schedule:
                optimizer = SGD_Adaptive(
                    model.parameters(),
                    lr=args.lr * (args.gamma**schedule_index),
                    weight_decay=args.weight_decay,
                    iter_count=iter_count,
                    restarting_iter=args.restart_schedule[schedule_index])
                schedule_index += 1

        elif args.optimizer.lower() == 'sradam':
            if epoch in args.schedule:
                optimizer = SRNAdam(
                    model.parameters(),
                    lr=args.lr * (args.gamma**schedule_index),
                    betas=(args.beta1, args.beta2),
                    iter_count=iter_count,
                    weight_decay=args.weight_decay,
                    restarting_iter=args.restart_schedule[schedule_index])
                schedule_index += 1

        elif args.optimizer.lower() == 'sradamw':
            if epoch in args.schedule:
                optimizer = SRAdamW(
                    model.parameters(),
                    lr=args.lr * (args.gamma**schedule_index),
                    betas=(args.beta1, args.beta2),
                    iter_count=iter_count,
                    weight_decay=args.weight_decay,
                    warmup=0,
                    restarting_iter=args.restart_schedule[schedule_index])
                schedule_index += 1

        elif args.optimizer.lower() == 'srradam':
            if epoch in args.schedule:
                optimizer = SRRAdam(
                    model.parameters(),
                    lr=args.lr * (args.gamma**schedule_index),
                    betas=(args.beta1, args.beta2),
                    iter_count=iter_count,
                    weight_decay=args.weight_decay,
                    warmup=0,
                    restarting_iter=args.restart_schedule[schedule_index])
                schedule_index += 1

        else:
            adjust_learning_rate(optimizer, epoch)

        if args.local_rank == 0:
            logger.file.write('\nEpoch: [%d | %d] LR: %f' %
                              (epoch + 1, args.epochs, state['lr']))

        if args.optimizer.lower() in ('srsgd', 'sradam', 'sradamw', 'srradam'):
            train_loss, train_top1, train_top5, iter_count = train(
                train_loader, model, criterion, optimizer, epoch, use_cuda,
                logger)
        else:
            train_loss, train_top1, train_top5 = train(train_loader, model,
                                                       criterion, optimizer,
                                                       epoch, use_cuda, logger)

        test_loss, test_top1, test_top5 = test(val_loader, model, criterion,
                                               epoch, use_cuda, logger)

        # append logger file
        if args.local_rank == 0:
            logger.append([
                state['lr'], train_loss, test_loss, train_top1, test_top1,
                train_top5, test_top5
            ])
            writer.add_scalars('train_loss', {args.model_name: train_loss},
                               epoch)
            writer.add_scalars('test_loss', {args.model_name: test_loss},
                               epoch)
            writer.add_scalars('train_top1', {args.model_name: train_top1},
                               epoch)
            writer.add_scalars('test_top1', {args.model_name: test_top1},
                               epoch)
            writer.add_scalars('train_top5', {args.model_name: train_top5},
                               epoch)
            writer.add_scalars('test_top5', {args.model_name: test_top5},
                               epoch)

        # save model
        is_best = test_top1 > best_top1
        best_top1 = max(test_top1, best_top1)
        best_top5 = max(test_top5, best_top5)
        if args.local_rank == 0:
            save_checkpoint(
                {
                    'epoch': epoch + 1,
                    'schedule_index': schedule_index,
                    'state_dict': model.state_dict(),
                    'top1': test_top1,
                    'top5': test_top5,
                    'best_top1': best_top1,
                    'best_top5': best_top5,
                    'optimizer': optimizer.state_dict(),
                },
                is_best,
                epoch,
                checkpoint=args.checkpoint)

            if epoch == args.schedule[-1]:
                logger.file.write('Best top1: %f at epoch %i' %
                                  (best_top1, epoch))
                logger.file.write('Best top5: %f at epoch %i' %
                                  (best_top5, epoch))
                print('Best top1: %f at epoch %i' % (best_top1, epoch))
                print('Best top5: %f at epoch %i' % (best_top5, epoch))
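                # An exclusive flock serializes appends from concurrent runs, so results
                # written to the shared all_results_imagenet.txt file do not interleave.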
                with open("./all_results_imagenet.txt", "a") as f:
                    fcntl.flock(f, fcntl.LOCK_EX)
                    f.write("%s\n" % args.checkpoint)
                    f.write("best_top1 %f, best_top5 %f at epoch %i\n\n" %
                            (best_top1, best_top5, epoch))
                    fcntl.flock(f, fcntl.LOCK_UN)

    if args.local_rank == 0:
        logger.file.write('Best top1: %f' % best_top1)
        logger.file.write('Best top5: %f' % best_top5)
        logger.close()
        logger.plot()
        savefig(os.path.join(args.checkpoint, 'log.eps'))
        print('Best top1: %f' % best_top1)
        print('Best top5: %f' % best_top5)
        with open("./all_results_imagenet.txt", "a") as f:
            fcntl.flock(f, fcntl.LOCK_EX)
            f.write("%s\n" % args.checkpoint)
            f.write("best_top1 %f, best_top5 %f\n\n" % (best_top1, best_top5))
            fcntl.flock(f, fcntl.LOCK_UN)
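

# The following main() fine-tunes BERT for GLUE sequence classification using
# pytorch-pretrained-bert, with optional fp16 training via apex and distributed training.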
def main():
    parser = argparse.ArgumentParser()

    ## Required parameters
    parser.add_argument(
        "--data_dir",
        default=None,
        type=str,
        required=True,
        help=
        "The input data dir. Should contain the .tsv files (or other data files) for the task."
    )
    parser.add_argument(
        "--bert_model",
        default=None,
        type=str,
        required=True,
        help="Bert pre-trained model selected in the list: bert-base-uncased, "
        "bert-large-uncased, bert-base-cased, bert-large-cased, bert-base-multilingual-uncased, "
        "bert-base-multilingual-cased, bert-base-chinese.")
    parser.add_argument("--task_name",
                        default=None,
                        type=str,
                        required=True,
                        help="The name of the task to train.")
    parser.add_argument(
        "--output_dir",
        default=None,
        type=str,
        required=True,
        help=
        "The output directory where the model predictions and checkpoints will be written."
    )

    ## Other parameters
    parser.add_argument(
        "--cache_dir",
        default="",
        type=str,
        help=
        "Where do you want to store the pre-trained models downloaded from s3")
    parser.add_argument(
        "--max_seq_length",
        default=128,
        type=int,
        help=
        "The maximum total input sequence length after WordPiece tokenization. \n"
        "Sequences longer than this will be truncated, and sequences shorter \n"
        "than this will be padded.")
    parser.add_argument("--do_train",
                        action='store_true',
                        help="Whether to run training.")
    parser.add_argument("--do_eval",
                        action='store_true',
                        help="Whether to run eval on the dev set.")
    parser.add_argument(
        "--do_lower_case",
        action='store_true',
        help="Set this flag if you are using an uncased model.")
    parser.add_argument("--train_batch_size",
                        default=32,
                        type=int,
                        help="Total batch size for training.")
    parser.add_argument("--eval_batch_size",
                        default=8,
                        type=int,
                        help="Total batch size for eval.")
    parser.add_argument("--learning_rate",
                        default=5e-5,
                        type=float,
                        help="The initial learning rate for Adam.")
    parser.add_argument("--num_train_epochs",
                        default=3.0,
                        type=float,
                        help="Total number of training epochs to perform.")
    parser.add_argument(
        "--warmup_proportion",
        default=0.1,
        type=float,
        help=
        "Proportion of training to perform linear learning rate warmup for. "
        "E.g., 0.1 = 10%% of training.")
    parser.add_argument("--no_cuda",
                        action='store_true',
                        help="Whether not to use CUDA when available")
    parser.add_argument("--local_rank",
                        type=int,
                        default=-1,
                        help="local_rank for distributed training on gpus")
    parser.add_argument('--seed',
                        type=int,
                        default=42,
                        help="random seed for initialization")
    parser.add_argument(
        '--gradient_accumulation_steps',
        type=int,
        default=1,
        help=
        "Number of updates steps to accumulate before performing a backward/update pass."
    )
    parser.add_argument(
        '--fp16',
        action='store_true',
        help="Whether to use 16-bit float precision instead of 32-bit")
    parser.add_argument(
        '--loss_scale',
        type=float,
        default=0,
        help=
        "Loss scaling to improve fp16 numeric stability. Only used when fp16 set to True.\n"
        "0 (default value): dynamic loss scaling.\n"
        "Positive power of 2: static loss scaling value.\n")
    parser.add_argument('--server_ip',
                        type=str,
                        default='',
                        help="Can be used for distant debugging.")
    parser.add_argument('--server_port',
                        type=str,
                        default='',
                        help="Can be used for distant debugging.")
    args = parser.parse_args()

    if args.server_ip and args.server_port:
        # Distant debugging - see https://code.visualstudio.com/docs/python/debugging#_attach-to-a-local-script
        import ptvsd
        print("Waiting for debugger attach")
        ptvsd.enable_attach(address=(args.server_ip, args.server_port),
                            redirect_output=True)
        ptvsd.wait_for_attach()

    processors = {
        "cola": ColaProcessor,
        "mnli": MnliProcessor,
        "mnli-mm": MnliMismatchedProcessor,
        "mrpc": MrpcProcessor,
        "sst-2": Sst2Processor,
        "sts-b": StsbProcessor,
        "qqp": QqpProcessor,
        "qnli": QnliProcessor,
        "rte": RteProcessor,
        "wnli": WnliProcessor,
    }

    output_modes = {
        "cola": "classification",
        "mnli": "classification",
        "mnli-mm": "classification",
        "mrpc": "classification",
        "sst-2": "classification",
        "sts-b": "regression",
        "qqp": "classification",
        "qnli": "classification",
        "rte": "classification",
        "wnli": "classification",
    }

    if args.local_rank == -1 or args.no_cuda:
        device = torch.device("cuda" if torch.cuda.is_available()
                              and not args.no_cuda else "cpu")
        n_gpu = torch.cuda.device_count()
    else:
        torch.cuda.set_device(args.local_rank)
        device = torch.device("cuda", args.local_rank)
        n_gpu = 1
        # Initializes the distributed backend which will take care of synchronizing nodes/GPUs
        torch.distributed.init_process_group(backend='nccl')
    logger.info(
        "device: {} n_gpu: {}, distributed training: {}, 16-bits training: {}".
        format(device, n_gpu, bool(args.local_rank != -1), args.fp16))

    if args.gradient_accumulation_steps < 1:
        raise ValueError(
            "Invalid gradient_accumulation_steps parameter: {}, should be >= 1"
            .format(args.gradient_accumulation_steps))

    args.train_batch_size = args.train_batch_size // args.gradient_accumulation_steps

    random.seed(args.seed)
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)
    if n_gpu > 0:
        torch.cuda.manual_seed_all(args.seed)

    if not args.do_train and not args.do_eval:
        raise ValueError(
            "At least one of `do_train` or `do_eval` must be True.")

    if os.path.exists(args.output_dir) and os.listdir(
            args.output_dir) and args.do_train:
        raise ValueError(
            "Output directory ({}) already exists and is not empty.".format(
                args.output_dir))
    if not os.path.exists(args.output_dir):
        os.makedirs(args.output_dir)

    task_name = args.task_name.lower()

    if task_name not in processors:
        raise ValueError("Task not found: %s" % (task_name))

    processor = processors[task_name]()
    output_mode = output_modes[task_name]

    label_list = processor.get_labels()
    num_labels = len(label_list)

    tokenizer = BertTokenizer.from_pretrained(args.bert_model,
                                              do_lower_case=args.do_lower_case)

    train_examples = None
    num_train_optimization_steps = None
    if args.do_train:
        train_examples = processor.get_train_examples(args.data_dir)
        num_train_optimization_steps = int(
            len(train_examples) / args.train_batch_size /
            args.gradient_accumulation_steps) * args.num_train_epochs
        if args.local_rank != -1:
            num_train_optimization_steps = num_train_optimization_steps // torch.distributed.get_world_size(
            )
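        # num_train_optimization_steps is the total number of optimizer updates
        # (examples / batch size / gradient_accumulation_steps, times epochs); under
        # distributed training it is divided by the world size since each process only
        # sees its own shard, and it later drives BertAdam's warmup schedule.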

    # Prepare model
    cache_dir = args.cache_dir if args.cache_dir else os.path.join(
        str(PYTORCH_PRETRAINED_BERT_CACHE), 'distributed_{}'.format(
            args.local_rank))
    model = BertForSequenceClassification.from_pretrained(
        args.bert_model, cache_dir=cache_dir, num_labels=num_labels)
    if args.fp16:
        model.half()
    model.to(device)
    if args.local_rank != -1:
        try:
            from apex.parallel import DistributedDataParallel as DDP
        except ImportError:
            raise ImportError(
                "Please install apex from https://www.github.com/nvidia/apex to use distributed and fp16 training."
            )

        model = DDP(model)
    elif n_gpu > 1:
        model = torch.nn.DataParallel(model)

    # Prepare optimizer
    param_optimizer = list(model.named_parameters())
    no_decay = ['bias', 'LayerNorm.bias', 'LayerNorm.weight']
    optimizer_grouped_parameters = [{
        'params':
        [p for n, p in param_optimizer if not any(nd in n for nd in no_decay)],
        'weight_decay':
        0.01
    }, {
        'params':
        [p for n, p in param_optimizer if any(nd in n for nd in no_decay)],
        'weight_decay':
        0.0
    }]
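    # Weight decay (0.01) is applied only to weight matrices; biases and LayerNorm
    # parameters (matched via the 'no_decay' substrings) are excluded, following the
    # usual BERT fine-tuning recipe.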
    if args.fp16:
        try:
            from apex.optimizers import FP16_Optimizer
            from apex.optimizers import FusedAdam
        except ImportError:
            raise ImportError(
                "Please install apex from https://www.github.com/nvidia/apex to use distributed and fp16 training."
            )

        optimizer = FusedAdam(optimizer_grouped_parameters,
                              lr=args.learning_rate,
                              bias_correction=False,
                              max_grad_norm=1.0)
        if args.loss_scale == 0:
            optimizer = FP16_Optimizer(optimizer, dynamic_loss_scale=True)
        else:
            optimizer = FP16_Optimizer(optimizer,
                                       static_loss_scale=args.loss_scale)

    else:
        optimizer = BertAdam(optimizer_grouped_parameters,
                             lr=args.learning_rate,
                             warmup=args.warmup_proportion,
                             t_total=num_train_optimization_steps)

    global_step = 0
    nb_tr_steps = 0
    tr_loss = 0
    if args.do_train:
        train_features = convert_examples_to_features(train_examples,
                                                      label_list,
                                                      args.max_seq_length,
                                                      tokenizer, output_mode)
        logger.info("***** Running training *****")
        logger.info("  Num examples = %d", len(train_examples))
        logger.info("  Batch size = %d", args.train_batch_size)
        logger.info("  Num steps = %d", num_train_optimization_steps)
        all_input_ids = torch.tensor([f.input_ids for f in train_features],
                                     dtype=torch.long)
        all_input_mask = torch.tensor([f.input_mask for f in train_features],
                                      dtype=torch.long)
        all_segment_ids = torch.tensor([f.segment_ids for f in train_features],
                                       dtype=torch.long)

        if output_mode == "classification":
            all_label_ids = torch.tensor([f.label_id for f in train_features],
                                         dtype=torch.long)
        elif output_mode == "regression":
            all_label_ids = torch.tensor([f.label_id for f in train_features],
                                         dtype=torch.float)

        train_data = TensorDataset(all_input_ids, all_input_mask,
                                   all_segment_ids, all_label_ids)
        if args.local_rank == -1:
            train_sampler = RandomSampler(train_data)
        else:
            train_sampler = DistributedSampler(train_data)
        train_dataloader = DataLoader(train_data,
                                      sampler=train_sampler,
                                      batch_size=args.train_batch_size)

        model.train()
        for _ in trange(int(args.num_train_epochs), desc="Epoch"):
            tr_loss = 0
            nb_tr_examples, nb_tr_steps = 0, 0
            for step, batch in enumerate(
                    tqdm(train_dataloader, desc="Iteration")):
                batch = tuple(t.to(device) for t in batch)
                input_ids, input_mask, segment_ids, label_ids = batch

                # define a new function to compute loss values for both output_modes
                logits = model(input_ids, segment_ids, input_mask, labels=None)

                if output_mode == "classification":
                    loss_fct = CrossEntropyLoss()
                    loss = loss_fct(logits.view(-1, num_labels),
                                    label_ids.view(-1))
                elif output_mode == "regression":
                    loss_fct = MSELoss()
                    loss = loss_fct(logits.view(-1), label_ids.view(-1))

                if n_gpu > 1:
                    loss = loss.mean()  # mean() to average on multi-gpu.
                if args.gradient_accumulation_steps > 1:
                    loss = loss / args.gradient_accumulation_steps

                if args.fp16:
                    optimizer.backward(loss)
                else:
                    loss.backward()

                tr_loss += loss.item()
                nb_tr_examples += input_ids.size(0)
                nb_tr_steps += 1
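                # Because the loss was divided by gradient_accumulation_steps above,
                # gradients simply accumulate across micro-batches; the optimizer step
                # (and, in fp16 mode, the manual warmup_linear LR update) runs only once
                # every gradient_accumulation_steps iterations.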
                if (step + 1) % args.gradient_accumulation_steps == 0:
                    if args.fp16:
                        # modify learning rate with the special warm-up BERT uses;
                        # if args.fp16 is False, BertAdam is used, which handles this automatically
                        lr_this_step = args.learning_rate * warmup_linear(
                            global_step / num_train_optimization_steps,
                            args.warmup_proportion)
                        for param_group in optimizer.param_groups:
                            param_group['lr'] = lr_this_step
                    optimizer.step()
                    optimizer.zero_grad()
                    global_step += 1

        # Save a trained model and the associated configuration
        model_to_save = model.module if hasattr(
            model, 'module') else model  # Only save the model itself
        output_model_file = os.path.join(args.output_dir, WEIGHTS_NAME)
        torch.save(model_to_save.state_dict(), output_model_file)
        output_config_file = os.path.join(args.output_dir, CONFIG_NAME)
        with open(output_config_file, 'w') as f:
            f.write(model_to_save.config.to_json_string())

        # Load a trained model and config that you have fine-tuned
        config = BertConfig(output_config_file)
        model = BertForSequenceClassification(config, num_labels=num_labels)
        model.load_state_dict(torch.load(output_model_file))
    else:
        model = BertForSequenceClassification.from_pretrained(
            args.bert_model, num_labels=num_labels)
    model.to(device)

    if args.do_eval and (args.local_rank == -1
                         or torch.distributed.get_rank() == 0):
        eval_examples = processor.get_dev_examples(args.data_dir)
        eval_features = convert_examples_to_features(eval_examples, label_list,
                                                     args.max_seq_length,
                                                     tokenizer, output_mode)
        logger.info("***** Running evaluation *****")
        logger.info("  Num examples = %d", len(eval_examples))
        logger.info("  Batch size = %d", args.eval_batch_size)
        all_input_ids = torch.tensor([f.input_ids for f in eval_features],
                                     dtype=torch.long)
        all_input_mask = torch.tensor([f.input_mask for f in eval_features],
                                      dtype=torch.long)
        all_segment_ids = torch.tensor([f.segment_ids for f in eval_features],
                                       dtype=torch.long)

        if output_mode == "classification":
            all_label_ids = torch.tensor([f.label_id for f in eval_features],
                                         dtype=torch.long)
        elif output_mode == "regression":
            all_label_ids = torch.tensor([f.label_id for f in eval_features],
                                         dtype=torch.float)

        eval_data = TensorDataset(all_input_ids, all_input_mask,
                                  all_segment_ids, all_label_ids)
        # Run prediction for full data
        eval_sampler = SequentialSampler(eval_data)
        eval_dataloader = DataLoader(eval_data,
                                     sampler=eval_sampler,
                                     batch_size=args.eval_batch_size)

        model.eval()
        eval_loss = 0
        nb_eval_steps = 0
        preds = []

        for input_ids, input_mask, segment_ids, label_ids in tqdm(
                eval_dataloader, desc="Evaluating"):
            input_ids = input_ids.to(device)
            input_mask = input_mask.to(device)
            segment_ids = segment_ids.to(device)
            label_ids = label_ids.to(device)

            with torch.no_grad():
                logits = model(input_ids, segment_ids, input_mask, labels=None)

            # create eval loss and other metric required by the task
            if output_mode == "classification":
                loss_fct = CrossEntropyLoss()
                tmp_eval_loss = loss_fct(logits.view(-1, num_labels),
                                         label_ids.view(-1))
            elif output_mode == "regression":
                loss_fct = MSELoss()
                tmp_eval_loss = loss_fct(logits.view(-1), label_ids.view(-1))

            eval_loss += tmp_eval_loss.mean().item()
            nb_eval_steps += 1
            if len(preds) == 0:
                preds.append(logits.detach().cpu().numpy())
            else:
                preds[0] = np.append(preds[0],
                                     logits.detach().cpu().numpy(),
                                     axis=0)

        eval_loss = eval_loss / nb_eval_steps
        preds = preds[0]
        if output_mode == "classification":
            preds = np.argmax(preds, axis=1)
        elif output_mode == "regression":
            preds = np.squeeze(preds)
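        # Classification tasks reduce the logits with argmax; STS-B (regression) just
        # squeezes the single-column predictions before compute_metrics is applied.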
        result = compute_metrics(task_name, preds, all_label_ids.numpy())
        loss = tr_loss / nb_tr_steps if args.do_train else None

        result['eval_loss'] = eval_loss
        result['global_step'] = global_step
        result['loss'] = loss

        output_eval_file = os.path.join(args.output_dir, "eval_results.txt")
        with open(output_eval_file, "w") as writer:
            logger.info("***** Eval results *****")
            for key in sorted(result.keys()):
                logger.info("  %s = %s", key, str(result[key]))
                writer.write("%s = %s\n" % (key, str(result[key])))

        # hack for MNLI-MM
        if task_name == "mnli":
            task_name = "mnli-mm"
            processor = processors[task_name]()

            if os.path.exists(args.output_dir +
                              '-MM') and os.listdir(args.output_dir +
                                                    '-MM') and args.do_train:
                raise ValueError(
                    "Output directory ({}) already exists and is not empty.".
                    format(args.output_dir + '-MM'))
            if not os.path.exists(args.output_dir + '-MM'):
                os.makedirs(args.output_dir + '-MM')

            eval_examples = processor.get_dev_examples(args.data_dir)
            eval_features = convert_examples_to_features(
                eval_examples, label_list, args.max_seq_length, tokenizer,
                output_mode)
            logger.info("***** Running evaluation *****")
            logger.info("  Num examples = %d", len(eval_examples))
            logger.info("  Batch size = %d", args.eval_batch_size)
            all_input_ids = torch.tensor([f.input_ids for f in eval_features],
                                         dtype=torch.long)
            all_input_mask = torch.tensor(
                [f.input_mask for f in eval_features], dtype=torch.long)
            all_segment_ids = torch.tensor(
                [f.segment_ids for f in eval_features], dtype=torch.long)
            all_label_ids = torch.tensor([f.label_id for f in eval_features],
                                         dtype=torch.long)

            eval_data = TensorDataset(all_input_ids, all_input_mask,
                                      all_segment_ids, all_label_ids)
            # Run prediction for full data
            eval_sampler = SequentialSampler(eval_data)
            eval_dataloader = DataLoader(eval_data,
                                         sampler=eval_sampler,
                                         batch_size=args.eval_batch_size)

            model.eval()
            eval_loss = 0
            nb_eval_steps = 0
            preds = []

            for input_ids, input_mask, segment_ids, label_ids in tqdm(
                    eval_dataloader, desc="Evaluating"):
                input_ids = input_ids.to(device)
                input_mask = input_mask.to(device)
                segment_ids = segment_ids.to(device)
                label_ids = label_ids.to(device)

                with torch.no_grad():
                    logits = model(input_ids,
                                   segment_ids,
                                   input_mask,
                                   labels=None)

                loss_fct = CrossEntropyLoss()
                tmp_eval_loss = loss_fct(logits.view(-1, num_labels),
                                         label_ids.view(-1))

                eval_loss += tmp_eval_loss.mean().item()
                nb_eval_steps += 1
                if len(preds) == 0:
                    preds.append(logits.detach().cpu().numpy())
                else:
                    preds[0] = np.append(preds[0],
                                         logits.detach().cpu().numpy(),
                                         axis=0)

            eval_loss = eval_loss / nb_eval_steps
            preds = preds[0]
            preds = np.argmax(preds, axis=1)
            result = compute_metrics(task_name, preds, all_label_ids.numpy())
            loss = tr_loss / nb_tr_steps if args.do_train else None

            result['eval_loss'] = eval_loss
            result['global_step'] = global_step
            result['loss'] = loss

            output_eval_file = os.path.join(args.output_dir + '-MM',
                                            "eval_results.txt")
            with open(output_eval_file, "w") as writer:
                logger.info("***** Eval results *****")
                for key in sorted(result.keys()):
                    logger.info("  %s = %s", key, str(result[key]))
                    writer.write("%s = %s\n" % (key, str(result[key])))
Example #9
0
class face_learner(object):
    def __init__(self, conf, inference=False):
        accuracy = 0.0
        logger.debug(conf)
        if conf.use_mobilfacenet:
            # self.model = MobileFaceNet(conf.embedding_size).to(conf.device)
            self.model = MobileFaceNet(conf.embedding_size).cuda()
            logger.debug('MobileFaceNet model generated')
        else:
            self.model = Backbone(conf.net_depth, conf.drop_ratio,
                                  conf.net_mode).cuda()  #.to(conf.device)
            logger.debug('{}_{} model generated'.format(
                conf.net_mode, conf.net_depth))
        if not inference:
            self.milestones = conf.milestones
            logger.info('loading data...')
            self.loader_tri, self.class_num_tri = get_train_loader(
                conf, 'emore', sample_identity=True)

            self.writer = SummaryWriter(conf.log_path)
            self.step = 0
            self.head_tri = Triplet().cuda()
            logger.debug('triplet head generated')

            paras_only_bn, paras_wo_bn = separate_bn_paras(self.model)

            if conf.use_mobilfacenet:
                self.optimizer = optim.SGD([{
                    'params': paras_wo_bn[:-1],
                    'weight_decay': 4e-5
                }, {
                    'params': [paras_wo_bn[-1]],
                    'weight_decay': 4e-4
                }, {
                    'params': paras_only_bn
                }],
                                           lr=conf.lr,
                                           momentum=conf.momentum)
            else:
                self.optimizer = optim.SGD([{
                    'params': paras_wo_bn,
                    'weight_decay': 5e-4
                }, {
                    'params': paras_only_bn
                }],
                                           lr=conf.lr,
                                           momentum=conf.momentum)
            # self.optimizer = torch.nn.parallel.DistributedDataParallel(optimizer,device_ids=[conf.argsed])
            # self.scheduler = optim.lr_scheduler.ReduceLROnPlateau(self.optimizer, patience=40, verbose=True)

            if conf.fp16:
                self.model, self.optimizer = amp.initialize(self.model,
                                                            self.optimizer,
                                                            opt_level="O2")
                self.model = DistributedDataParallel(self.model).cuda()
            else:
                self.model = torch.nn.parallel.DistributedDataParallel(
                    self.model,
                    device_ids=[conf.argsed],
                    find_unused_parameters=True).cuda(
                    )  #add line for distributed

            self.board_loss_every = len(self.loader_tri) // 100
            self.evaluate_every = len(self.loader_tri) // 20
            self.save_every = len(self.loader_tri) // 2
            self.agedb_30, self.cfp_fp, self.lfw, self.agedb_30_issame, self.cfp_fp_issame, self.lfw_issame = get_val_data(
                Path(self.loader_tri.dataset.root).parent)
        else:
            self.threshold = conf.threshold
            self.loader, self.query_ds, self.gallery_ds = get_test_loader(conf)

    def save_state(self,
                   conf,
                   epoch,
                   accuracy,
                   to_save_folder=False,
                   extra=None,
                   model_only=False):
        if to_save_folder:
            save_path = conf.save_path
        else:
            save_path = conf.model_path

        if not os.path.exists(save_path):
            os.makedirs(save_path, exist_ok=True)

        torch.save(
            self.model.state_dict(),
            save_path / ('model_{}_{}_acc:{:.4f}_{}.pth'.format(
                epoch, self.step, accuracy, extra)))
        if not model_only:
            torch.save(
                self.optimizer.state_dict(),
                save_path / ('optimizer_{}_{}_acc:{:.4f}_{}.pth'.format(
                    epoch, self.step, accuracy, extra)))

    def load_network(self, conf, save_path):
        state_dict = torch.load(save_path,
                                map_location='cuda:{}'.format(conf.local_rank))
        # create new OrderedDict that does not contain `module.`
        new_state_dict = OrderedDict()
        for k, v in state_dict.items():
            # logger.debug('key {}'.format(k))
            namekey = k[7:]
            # logger.debug('key {}'.format(namekey))  # remove 'module.'
            new_state_dict[namekey] = v
        # load params
        return new_state_dict

    def load_state(self,
                   conf,
                   fixed_str,
                   from_save_folder=False,
                   model_only=False):
        if from_save_folder:
            save_path = conf.save_path
        else:
            save_path = conf.model_path
        if conf.resume:
            self.model.load_state_dict(
                torch.load(save_path / 'model_{}'.format(fixed_str),
                           map_location='cuda:{}'.format(conf.local_rank)))
        else:
            self.model.load_state_dict(
                self.load_network(conf,
                                  save_path / 'model_{}'.format(fixed_str)))

        if not model_only:
            self.optimizer.load_state_dict(
                torch.load(save_path / 'optimizer_{}'.format(fixed_str)))
            logger.info('load optimizer {}'.format(self.optimizer))
            # amp.load_state_dict(torch.load(save_path / 'amp_{}'.format(fixed_str)))

    def board_val(self, db_name, accuracy, best_threshold, roc_curve_tensor):
        self.writer.add_scalar('{}_accuracy'.format(db_name), accuracy,
                               self.step)
        self.writer.add_scalar('{}_best_threshold'.format(db_name),
                               best_threshold, self.step)
        self.writer.add_image('{}_roc_curve'.format(db_name), roc_curve_tensor,
                              self.step)

    def evaluate(self, conf, carray, issame, nrof_folds=5, tta=False):
        self.model.eval()
        idx = 0
        embeddings = np.zeros([len(carray), conf.embedding_size])
        with torch.no_grad():
            while idx + conf.batch_size <= len(carray):
                batch = torch.tensor(carray[idx:idx + conf.batch_size])
                if tta:
                    fliped = hflip_batch(batch)
                    emb_batch = self.model(batch.cuda())[1] + self.model(
                        fliped.cuda())[1]
                    embeddings[idx:idx +
                               conf.batch_size] = l2_norm(emb_batch).cpu()
                else:
                    embeddings[idx:idx + conf.batch_size] = l2_norm(
                        self.model(batch.cuda())[1]).cpu()
                idx += conf.batch_size
            if idx < len(carray):
                batch = torch.tensor(carray[idx:])
                if tta:
                    fliped = hflip_batch(batch)
                    emb_batch = self.model(batch.cuda())[1] + self.model(
                        fliped.cuda())[1]
                    embeddings[idx:] = l2_norm(emb_batch)
                else:
                    embeddings[idx:] = l2_norm(self.model(
                        batch.cuda())[1]).cpu()
        tpr, fpr, accuracy, best_thresholds = evaluate(embeddings, issame,
                                                       nrof_folds)
        buf = gen_plot(fpr, tpr)
        roc_curve = Image.open(buf)
        roc_curve_tensor = trans.ToTensor()(roc_curve)
        return accuracy.mean(), best_thresholds.mean(), roc_curve_tensor

    # true top 1, false top 1, miss
    def compute_true_false_miss(self, conf, log_dir, feat_path, tta):
        def gen_distmat(qf, q_pids, gf, g_pids):
            m, n = qf.shape[0], gf.shape[0]
            logger.debug('query shape {}, gallery shape {}'.format(
                qf.shape, gf.shape))
            # logger.debug('q_pids {}, g_pids {}'.format(q_pids, g_pids))
            distmat = torch.pow(qf, 2).sum(dim=1, keepdim=True).expand(m, n) + \
                      torch.pow(gf, 2).sum(dim=1, keepdim=True).expand(n, m).t()
            distmat.addmm_(1, -2, qf, gf.t())
            distmat = distmat.cpu().numpy()
            return distmat
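        # gen_distmat computes all pairwise squared Euclidean distances via the expansion
        # ||q - g||^2 = ||q||^2 + ||g||^2 - 2 * q.g; the cross term is added by addmm_
        # (using the older positional signature addmm_(beta, alpha, mat1, mat2)).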

        def distance(emb1, emb2):
            diff = np.subtract(emb1, emb2)
            dist = np.sum(np.square(diff), 1)
            return dist

        self.model.eval()
        if conf.gen_feature:
            with torch.no_grad():
                query_feature, query_label = extract_feature(
                    conf, self.model, self.loader['query']['dl'], tta)
                gallery_feature, gallery_label = extract_feature(
                    conf, self.model, self.loader['gallery']['dl'], tta)
            # result = {'query_feature': query_feature.numpy(), 'query_label': query_label,
            #     'gallery_feature': gallery_feature.numpy(), 'gallery_label': gallery_label}

            result = {
                'query_feature': query_feature.numpy(),
                'query_label': query_label.numpy(),
                'gallery_feature': gallery_feature.numpy(),
                'gallery_label': gallery_label.numpy()
            }
            scipy.io.savemat(feat_path, result)

        else:
            result = scipy.io.loadmat(feat_path)
            query_feature = torch.from_numpy(result['query_feature'])
            query_label = torch.from_numpy(result['query_label'])[0]
            gallery_feature = torch.from_numpy(result['gallery_feature'])
            gallery_label = torch.from_numpy(result['gallery_label'])[0]

        distmat = gen_distmat(query_feature, query_label, gallery_feature,
                              gallery_label)

        # record txt
        with open(os.path.join(log_dir, 'result.txt'), 'at') as f:
            f.write('%s\t%s\t%s\t%s\n' % ('threshold', 'acc', 'err', 'miss'))

        # record excel
        xls_file = xlwt.Workbook()
        sheet_1 = xls_file.add_sheet('sheet_1', cell_overwrite_ok=True)
        row = 0
        path_excel = os.path.join(log_dir, 'result.xls')

        sheet_title = ['threshold', 'acc', 'err', 'miss']
        for i_sheet in range(len(sheet_title)):
            sheet_1.write(row, i_sheet, sheet_title[i_sheet])
        xls_file.save(path_excel)
        row += 1

        index = np.argsort(distmat)  # from small to large
        max_index = index[:, 0]

        query_list_file = 'data/probe.txt'
        gallery_list_file = 'data/gallery.txt'
        err_rank1 = os.path.join(log_dir, 'err_rank1.txt')
        data_path = DataPath(query_list_file, gallery_list_file)
        with open(err_rank1, 'at') as f:
            f.write('%s\t\t\t%s\n' % ('query', 'gallery'))

        thresholds = np.arange(0.4, 2, 0.01)
        for threshold in thresholds:
            acc, err, miss = compute_rank1(distmat, max_index, query_label,
                                           gallery_label, threshold, data_path,
                                           err_rank1)
            # record txt
            with open(os.path.join(log_dir, 'result.txt'), 'at') as f:
                f.write('%.6f\t%.6f\t%.6f\t%.6f\n' %
                        (threshold, acc, err, miss))

            # record excel
            list_data = [threshold, acc, err, miss]
            for i_1 in range(len(list_data)):
                sheet_1.write(row, i_1, list_data[i_1])
            xls_file.save(path_excel)
            row += 1

    def train(self, conf, epochs):
        self.model.train()
        # logger.debug('model {}'.format(self.model))
        running_loss = 0.

        # resume training from a checkpoint
        if conf.resume:
            logger.debug('resume...')
            self.load_state(conf, 'ir_se100.pth', from_save_folder=True)

        logger.debug('optimizer {}'.format(self.optimizer))
        for epoch in range(epochs):
            logger.debug('epoch {} started'.format(epoch))
            for data_tri in tqdm(iter(self.loader_tri)):
                if self.step in self.milestones:
                    self.schedule_lr()
                imgs_tri, labels_tri = data_tri
                imgs_tri = imgs_tri.cuda()
                labels_tri = labels_tri.cuda()
                self.optimizer.zero_grad()
                # embeddings_tri, _ = self.model(imgs_tri)
                _, embeddings_tri = self.model(imgs_tri)
                loss = self.head_tri(embeddings_tri, labels_tri)
                if conf.fp16:  # backward through the amp-scaled loss in fp16 mode
                    with amp.scale_loss(loss, self.optimizer) as scaled_loss:
                        scaled_loss.backward()
                else:
                    loss.backward()

                running_loss += loss.item()
                self.optimizer.step()

                if self.step % self.board_loss_every == 0 and self.step != 0:  # log the averaged training loss
                    loss_board = running_loss / self.board_loss_every
                    self.writer.add_scalar('train_loss', loss_board, self.step)
                    running_loss = 0.

                if self.step % self.evaluate_every == 0 and self.step != 0:  # periodically evaluate on the verification sets
                    accuracy, best_threshold, roc_curve_tensor = self.evaluate(
                        conf, self.agedb_30, self.agedb_30_issame)
                    self.board_val('agedb_30', accuracy, best_threshold,
                                   roc_curve_tensor)
                    accuracy, best_threshold, roc_curve_tensor = self.evaluate(
                        conf, self.lfw, self.lfw_issame)
                    self.board_val('lfw', accuracy, best_threshold,
                                   roc_curve_tensor)
                    accuracy, best_threshold, roc_curve_tensor = self.evaluate(
                        conf, self.cfp_fp, self.cfp_fp_issame)
                    self.board_val('cfp_fp', accuracy, best_threshold,
                                   roc_curve_tensor)
                    # logger.debug('optimizer {}'.format(self.optimizer))
                    logger.debug(
                        'epoch {}, step {}, loss {:.4f}, acc {:.4f}'.format(
                            epoch, self.step, loss.item(), accuracy))
                    self.model.train()

                if conf.local_rank == 0 and epoch >= 10 and self.step % self.save_every == 0 and self.step != 0:
                    # if conf.local_rank == 0 and self.step % self.save_every == 0 and self.step != 0:
                    self.save_state(conf, epoch, accuracy)

                self.step += 1

        self.save_state(conf,
                        epoch,
                        accuracy,
                        to_save_folder=True,
                        extra='final')

    def schedule_lr(self):
        for params in self.optimizer.param_groups:
            params['lr'] /= 10
        logger.debug('optimizer {}'.format(self.optimizer))

    def infer(self, conf, faces, target_embs, tta=False):
        '''
        faces : list of PIL Images to identify
        target_embs : [n, 512] computed embeddings of faces in the facebank
        tta : test-time augmentation (horizontal flip only)
        '''
        embs = []
        for img in faces:
            if tta:
                mirror = trans.functional.hflip(img)
                emb = self.model(
                    conf.test_transform(img).to(conf.device).unsqueeze(0))
                emb_mirror = self.model(
                    conf.test_transform(mirror).to(conf.device).unsqueeze(0))
                embs.append(l2_norm(emb + emb_mirror))
            else:
                embs.append(
                    self.model(
                        conf.test_transform(img).to(conf.device).unsqueeze(0)))
        source_embs = torch.cat(embs)
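        # Broadcasting trick: source_embs [n, 512, 1] minus target_embs.T [1, 512, m]
        # yields an [n, 512, m] difference tensor; summing squares over dim=1 gives the
        # [n, m] squared-distance matrix, and matches above self.threshold are set to -1.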

        diff = source_embs.unsqueeze(-1) - target_embs.transpose(
            1, 0).unsqueeze(0)
        dist = torch.sum(torch.pow(diff, 2), dim=1)
        minimum, min_idx = torch.min(dist, dim=1)
        min_idx[minimum > self.threshold] = -1  # if no match, set idx to -1
        return min_idx, minimum
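
# A minimal usage sketch (the `conf` helper below is assumed, not defined in this example):
#
#   conf = get_config()                      # hypothetical config object with the fields used above
#   learner = face_learner(conf)             # builds model, data loaders and optimizer
#   learner.train(conf, conf.epochs)         # runs the triplet training loop defined above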
Example #10
0
def main():
    parser = argparse.ArgumentParser(description='PyTorch ImageNet Training')
    parser.add_argument(
        'data',
        metavar='DIR',
        nargs='*',
        help='path(s) to dataset (if one path is provided, it is assumed\n' +
        'to have subdirectories named "train" and "val"; alternatively,\n' +
        'train and val paths can be specified directly by providing both paths as arguments)'
    )
    parser.add_argument('-a',
                        '--arch',
                        metavar='ARCH',
                        default='resnet18',
                        choices=model_names,
                        help='model architecture: ' + ' | '.join(model_names) +
                        ' (default: resnet18)')
    parser.add_argument('-j',
                        '--workers',
                        default=4,
                        type=int,
                        metavar='N',
                        help='number of data loading workers (default: 4)')
    parser.add_argument('--epochs',
                        default=60,
                        type=int,
                        metavar='N',
                        help='number of total epochs to run')
    parser.add_argument('--start-epoch',
                        default=0,
                        type=int,
                        metavar='N',
                        help='manual epoch number (useful on restarts)')
    parser.add_argument(
        '-bs',
        '--batch-size',
        default=128,
        type=int,
        metavar='N',
        help='batch size for descriptor generation (default: 128)')
    parser.add_argument('-lr',
                        '--learning-rate',
                        default=0.1,
                        type=float,
                        metavar='LR',
                        help='initial learning rate',
                        dest='lr')
    parser.add_argument('--momentum',
                        default=0.9,
                        type=float,
                        metavar='M',
                        help='momentum')
    parser.add_argument('--wd',
                        '--weight-decay',
                        default=1e-4,
                        type=float,
                        metavar='W',
                        help='weight decay (default: 1e-4)',
                        dest='weight_decay')
    parser.add_argument('-p',
                        '--print-freq',
                        default=10,
                        type=int,
                        metavar='N',
                        help='print frequency (default: 10)')
    #parser.add_argument('--evaluate', dest='evaluate', action='store_true',
    #                    help='evaluate model on validation set')
    parser.add_argument('--fp16',
                        action='store_true',
                        help='Run model in fp16 mode.')
    parser.add_argument('--dali_cpu',
                        action='store_true',
                        help='Runs CPU based version of DALI pipeline.')
    parser.add_argument(
        '--static-loss-scale',
        type=float,
        default=1,
        help=
        'Static loss scale, positive power of 2 values can improve fp16 convergence.'
    )
    parser.add_argument(
        '--dynamic-loss-scale',
        action='store_true',
        help='Use dynamic loss scaling.  If supplied, this argument supersedes '
        + '--static-loss-scale.')
    parser.add_argument('--prof',
                        dest='prof',
                        action='store_true',
                        help='Only run 10 iterations for profiling.')
    parser.add_argument('-t',
                        '--test',
                        action='store_true',
                        help='Launch test mode with preset arguments')

    parser.add_argument("--local_rank", default=0, type=int)
    # added
    parser.add_argument('-ts',
                        '--train-size',
                        type=int,
                        default=0,
                        metavar='N',
                        help='number of examples for training (default: 0)')
    parser.add_argument(
        '-ir',
        '--imbalance-ratio',
        type=int,
        default=1,
        metavar='N',
        help=
        'ratio of 0..499 to 500..999 labels in the training dataset drawn from uniform distribution'
    )
    parser.add_argument(
        '-nr',
        '--noisy-ratio',
        type=float,
        default=0.0,
        metavar='N',
        help=
        'ratio of noisy(random) labels in the training dataset drawn from uniform distribution'
    )
    parser.add_argument(
        '-ens',
        '--ensemble-size',
        type=int,
        default=1,
        metavar='E',
        help='defines size of ensemble or, by default, no ensemble if = 1')
    parser.add_argument('-e',
                        '--ensemble-index',
                        type=int,
                        default=0,
                        metavar='E',
                        help='defines index of ensemble')
    parser.add_argument('--save-folder',
                        default='../local_data/ImageNet',
                        type=str,
                        help='dir to save data')
    parser.add_argument('-r',
                        '--run-folder',
                        default='run99',
                        type=str,
                        help='dir to save run')
    parser.add_argument('-b',
                        '--batch',
                        type=int,
                        default=0,
                        metavar='N',
                        help='augmentation batch (iteration) (default: 0)')
    parser.add_argument(
        '-as',
        '--augment-size',
        type=int,
        default=64000,
        metavar='N',
        help='augmentation dataset size for training (default: 64000)')
    parser.add_argument(
        '-sub',
        '--subtype-method',
        type=str,
        default='grad',
        metavar='N',
        help='method to generate gradient information (default: grad)')
    parser.add_argument('-aug',
                        '--augment-method',
                        type=str,
                        default='random',
                        metavar='N',
                        help='method to match distributions (default: random)')
    parser.add_argument('-dl',
                        '--descriptor-length',
                        type=int,
                        default=0,
                        metavar='L',
                        help='descriptor length (default: 0)')
    parser.add_argument(
        '-unsup',
        '--unsupervised',
        type=int,
        default=0,
        help='unsupervised pretraining as initial step or random weights')

    args = parser.parse_args()
    cudnn.benchmark = True

    # test mode, use default args for sanity test
    if args.test:
        args.fp16 = False
        args.epochs = 1
        args.start_epoch = 0
        args.arch = 'resnet18'
        args.batch_size = 256
        args.data = []
        args.prof = True
        args.data.append('/data/imagenet/train-jpeg/')
        args.data.append('/data/imagenet/val-jpeg/')

    if not len(args.data):
        raise Exception("error: too few data arguments")

    args.distributed = False
    if 'WORLD_SIZE' in os.environ:
        args.distributed = int(os.environ['WORLD_SIZE']) > 1

    args.gpu = 0
    args.world_size = 1
    if args.distributed:
        args.gpu = args.local_rank % torch.cuda.device_count()
        torch.cuda.set_device(args.gpu)
        torch.distributed.init_process_group(backend='nccl',
                                             init_method='env://')
        args.world_size = torch.distributed.get_world_size()

    if args.fp16:
        assert torch.backends.cudnn.enabled, "fp16 mode requires cudnn backend to be enabled."

    if args.static_loss_scale != 1.0:
        if not args.fp16:
            print(
                "Warning:  if --fp16 is not used, static_loss_scale will be ignored."
            )

    # Data loading code
    if len(args.data) == 1:
        train_dir = os.path.join(args.data[0], 'train')
        val_dir = os.path.join(args.data[0], 'val')
    else:
        train_dir = args.data[0]
        val_dir = args.data[1]

    if (args.arch == "inception_v3"):
        crop_size = 299
        val_size = 320  # I chose this value arbitrarily, we can adjust.
    else:
        crop_size = 224
        val_size = 256
    #
    print('Running main_descr.py with {} {} {} {} {}'.format(
        args.batch, args.train_size, args.augment_size, args.subtype_method,
        args.augment_method))
    # pipe for val dataset
    val_list_file = '{}/{}'.format(args.save_folder, 'processed/val_list.txt')
    pipe = HybridValPipe(batch_size=args.batch_size,
                         num_threads=args.workers,
                         device_id=args.local_rank,
                         data_dir=val_dir,
                         file_list=val_list_file,
                         crop=crop_size,
                         local_rank=args.local_rank,
                         world_size=args.world_size,
                         size=val_size)
    pipe.build()
    val_loader = DALIClassificationIterator(
        pipe, size=int(pipe.epoch_size("Reader") / args.world_size))
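    # The DALI iterator's size is the reader's epoch size divided by world_size, so each
    # distributed worker iterates only over its own shard of the validation set.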

    # create model
    print("=> creating model '{}'".format(args.arch))
    model = models.__dict__[args.arch](L=args.descriptor_length)
    model = model.cuda()
    if args.fp16:
        model = network_to_half(model)
    if args.distributed:
        # shared param/delay all reduce turns off bucketing in DDP, for lower latency runs this can improve perf
        # for the older version of APEX please use shared_param, for newer one it is delay_allreduce
        model = DDP(model, delay_allreduce=True)

    # define loss function (criterion) and optimizer
    criterion = nn.CrossEntropyLoss().cuda()
    optimizer = torch.optim.SGD(model.parameters(),
                                args.lr,
                                momentum=args.momentum,
                                weight_decay=args.weight_decay)
    if args.fp16:
        optimizer = FP16_Optimizer(optimizer,
                                   static_loss_scale=args.static_loss_scale,
                                   dynamic_loss_scale=args.dynamic_loss_scale)

    # load checkpoint
    model_folder = '{}/{}/checkpoint'.format(args.save_folder, args.run_folder)
    descr_folder = '{}/{}/descr'.format(args.save_folder, args.run_folder)
    assert os.path.isdir(
        model_folder), 'Error: no model checkpoint directory found!'
    assert os.path.isdir(descr_folder), 'Error: no descriptor directory found!'
    # load existing
    if args.unsupervised == 1:
        unsup_prefix = 'unsup_'
    else:
        unsup_prefix = ''
    #
    descr_postfix = '{}batch_B_ir_{}_nr_{}_sub_{}_aug_{}_L_{}'.format(
        unsup_prefix,  # we do not save descriptors for each iteration here due to large size
        args.imbalance_ratio,
        args.noisy_ratio,
        args.subtype_method,
        args.augment_method,
        args.descriptor_length)
    if args.batch == 0:
        model_postfix = '{}batch_{}_ir_{}_nr_{}_sub_{}_aug_{}'.format(
            unsup_prefix, args.batch, args.imbalance_ratio, args.noisy_ratio,
            'none', 'none')
        if args.ensemble_size > 1:
            checkpoint_file = '{}/init_{}_E_{}.pt'.format(
                model_folder, model_postfix, args.ensemble_index)
        else:
            checkpoint_file = '{}/init_{}.pt'.format(model_folder,
                                                     model_postfix)
    else:
        model_postfix = '{}batch_{}_size_{}_ir_{}_nr_{}_sub_{}_aug_{}_L_{}'.format(
            unsup_prefix, args.batch, args.train_size, args.imbalance_ratio,
            args.noisy_ratio, args.subtype_method, args.augment_method,
            args.descriptor_length)
        if args.ensemble_size > 1:
            checkpoint_file = '{}/best_{}_E_{}.pt'.format(
                model_folder, model_postfix, args.ensemble_index)
        else:
            checkpoint_file = '{}/best_{}.pt'.format(model_folder,
                                                     model_postfix)
    #
    if args.ensemble_size > 1:
        index_use_file = '{}/{}/index_list_{}_E_{}.npy'.format(
            args.save_folder, args.run_folder, model_postfix,
            args.ensemble_size)
    else:
        index_use_file = '{}/{}/index_list_{}.npy'.format(
            args.save_folder, args.run_folder, model_postfix)
    #
    if (args.imbalance_ratio == 1) and (args.noisy_ratio
                                        == 0.0):  # use original train dataset
        full_train_list_file = '{}/processed/train_list.txt'.format(
            args.save_folder)
    else:
        full_train_list_file = '{}/{}/full_train_list_ir_{}_nr_{}.txt'.format(
            args.save_folder, args.run_folder, args.imbalance_ratio,
            args.noisy_ratio)
    #
    with open(full_train_list_file) as f:
        full_train_list = f.readlines()
    full_train_list = [l.strip() for l in full_train_list]

    index_all = range(len(full_train_list))
    set_all = set(index_all)

    if os.path.isfile(checkpoint_file) and (os.path.isfile(index_use_file) or
                                            (args.batch == 0)):
        checkpoint = torch.load(checkpoint_file)
        if args.batch == 0 and args.unsupervised == 1:
            model.load_state_dict(
                {
                    k: v
                    for k, v in checkpoint['state_dict'].items()
                    if 'fc' not in k
                },
                strict=False)  # copy all but last linear layer!
        else:
            model.load_state_dict(checkpoint['state_dict'])
        #
        if args.batch == 0:
            index_use = []
        else:
            index_use = np.load(index_use_file)
        #
        set_use = set(index_use)
        set_avail = set_all - set_use
        index_avail = list(set_avail)
        avail_size = len(index_avail)
        # checks:
        assert len(index_use
                   ) == args.train_size, 'Number of used examples should match'
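        # Active-learning bookkeeping: index_use lists the examples selected in previous
        # augmentation batches (and must match args.train_size), so set_all - set_use is
        # the remaining pool from which the next args.augment_size examples are drawn.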
    else:
        print('Some checkpoint or index files are missing in main_descr.py!')
        sys.exit(0)
    #
    val_prefix = 'val'
    train_prefix = 'train'
    if args.ensemble_size > 1:
        descr_train_file = '{}/{}_{}_E_{}.pt'.format(descr_folder,
                                                     train_prefix,
                                                     descr_postfix,
                                                     args.ensemble_index)
        print('Descriptor files =', descr_train_file)
    else:
        descr_train_file = '{}/{}_{}.pt'.format(descr_folder, train_prefix,
                                                descr_postfix)
        descr_val_file = '{}/{}_{}.pt'.format(descr_folder, val_prefix,
                                              descr_postfix)
        print('Descriptor files =', descr_train_file, descr_val_file)
    # augmentation methods
    if args.augment_method == 'random':
        print('No descriptors needed in random augmentation case!')
        topindex_sim = random.sample(range(avail_size), args.augment_size)
    elif ('topUncert' in args.augment_method) and (
            args.ensemble_size == 1) and os.path.isfile(descr_train_file):
        print('Uncertainty-based without ensembling!')
        if args.batch == 0:  # for some reason random sampling is better initially (same setup as paper experiments)
            print(
                '!!!!!!!!!!!!!!!!!!!!! Random iteration when b=0 !!!!!!!!!!!!!!!!!!!!!'
            )
            topindex_sim = random.sample(range(avail_size), args.augment_size)
        else:
            descr_train = torch.load(descr_train_file)
            assert descr_train.size(0) == len(
                full_train_list
            ), 'Number of train descriptors should be equal to number of entries in full train file'
            descr_avail = descr_train[index_avail]
            print(descr_avail.size(), args.augment_size)
            _, topkindex_sim = torch.topk(descr_avail,
                                          args.augment_size,
                                          largest=True)
            topindex_sim = topkindex_sim.tolist()
    elif ('topUncert' in args.augment_method) and (
            args.ensemble_size > 1) and os.path.isfile(descr_train_file):
        print('Uncertainty-based with ensembling!')
        ensemble_index_file = '{}/{}/ensemble_list_{}_E_{}.npy'.format(
            args.save_folder, args.run_folder, model_postfix,
            args.ensemble_size)
        if args.batch == 0:  # for some reason random sampling is better initially (same setup as paper experiments)
            print(
                '!!!!!!!!!!!!!!!!!!!!! Random iteration when b=0 !!!!!!!!!!!!!!!!!!!!!'
            )
            topindex_sim = random.sample(range(avail_size), args.augment_size)
        else:
            if args.ensemble_index == 0:  # generate augmentation list only once
                descr_train = torch.load(descr_train_file)
                assert descr_train.size(0) == len(
                    full_train_list
                ), 'Number of train descriptors should be equal to number of entries in full train file'
                descr_avail = descr_train[index_avail].cuda()  # PxCx2
                for e in range(1,
                               args.ensemble_size):  # average ensemble results
                    descr_train_file = '{}/{}_{}_E_{}.pt'.format(
                        descr_folder, train_prefix, descr_postfix, e)
                    if os.path.isfile(descr_train_file):
                        descr_train = torch.load(descr_train_file)
                        descr_avail = descr_avail + descr_train[
                            index_avail].cuda()
                    else:
                        print(
                            'Some descriptor files are missing in ensemble-based methods'
                        )
                        sys.exit(0)
                #
                P = descr_avail.size(0)
                r = 1e-8  # regularization for log() to avoid nan
                if 'ent' in args.subtype_method:
                    pT = descr_avail[:, :, 0]  # PxC
                    max_entropy = -torch.sum(torch.mul(pT, torch.log2(pT + r)),
                                             1)  # P
                    f = max_entropy
                elif 'bald' in args.subtype_method:
                    pT = descr_avail[:, :, 0]  # PxC
                    pL = descr_avail[:, :, 1]  # PxC
                    max_entropy = -torch.sum(torch.mul(pT, torch.log2(pT + r)),
                                             1)  # P
                    bald = max_entropy + torch.sum(pL, 1)  # P
                    f = bald
                elif 'var' in args.subtype_method:
                    pT = descr_avail[:, :, 0]  # PxC
                    fMax, _ = torch.max(pT, 1)
                    var_ratio = 1 - fMax / args.ensemble_size
                    f = var_ratio
                else:
                    print('Wrong ensemble uncert method!')
                    sys.exit(0)
                #
                fCpu = f.cpu()
                _, topkindex_sim = torch.topk(fCpu,
                                              args.augment_size,
                                              largest=True)
                topindex_sim = topkindex_sim.tolist()
                np.save(ensemble_index_file, topindex_sim)
            else:  # reuse existing ensemble list
                topindex_sim = np.load(ensemble_index_file)
    elif ('topK' in args.augment_method) and os.path.isfile(
            descr_train_file) and os.path.isfile(descr_val_file):
        descr_train = torch.load(descr_train_file)
        assert descr_train.size(0) == len(
            full_train_list
        ), 'Number of train descriptors should be equal to number of entries in full train file'
        descr_val = torch.load(descr_val_file)
        assert descr_val.size(
            0) == V, 'Number of val descriptors should match the validation set size'
        index_miss = get_miss(args, val_loader, model, criterion)
        index_miss = [v for v in index_miss
                      if v < V]  # workaround for DALI bug
        val_loader.reset()
        descr_miss = descr_val[index_miss].cuda()
        descr_avail = descr_train[index_avail].cuda()
        # some constants to calculate sub similarity matrices
        L = args.descriptor_length
        M = descr_miss.size(0)
        P = descr_avail.size(0)
        K = 128  # calculate extra because of potential overlap
        S = 128  # divide MxL matrix into SxL chunks to fit into memory
        print('Feature-matching augmentation with descriptors:',
              len(index_avail), descr_val.size(), descr_miss.size(),
              descr_train.size(), descr_avail.size())
        # make a list for multiscale attention
        if L == 448:
            I = [0, 64, 192, 448]
        elif L == 512:
            I = [0, 512]
        elif L == 768:
            I = [0, 256, 768]
        elif L == 0:
            I = []
        else:
            print('Wrong descriptor length in main_descr.py')
            sys.exit(0)
        #
        fvec_miss = descr_miss[:, :, 0]  # MxL
        grad_miss = descr_miss[:, :, 1]  # MxL
        fvec_avail = descr_avail[:, :, 0]  # PxL
        grad_avail = descr_avail[:, :, 1]  # PxL
        #
        eps = 1e-10  # small regularization constant
        # Pearson correlation (PCC) normalization: zero-mean, unit-std per attention segment
        if 'Pcc' in args.augment_method:
            for i in range(1, len(I)):
                fvec_miss[:, I[i - 1]:I[i]] -= torch.mean(
                    fvec_miss[:, I[i - 1]:I[i]], 1, keepdim=True)
                fvec_avail[:, I[i - 1]:I[i]] -= torch.mean(
                    fvec_avail[:, I[i - 1]:I[i]], 1, keepdim=True)
                grad_miss[:, I[i - 1]:I[i]] -= torch.mean(
                    grad_miss[:, I[i - 1]:I[i]], 1, keepdim=True)
                grad_avail[:, I[i - 1]:I[i]] -= torch.mean(
                    grad_avail[:, I[i - 1]:I[i]], 1, keepdim=True)
                fvec_miss[:, I[i - 1]:I[i]] /= (
                    torch.std(fvec_miss[:, I[i - 1]:I[i]], 1, keepdim=True) +
                    eps)
                fvec_avail[:, I[i - 1]:I[i]] /= (
                    torch.std(fvec_avail[:, I[i - 1]:I[i]], 1, keepdim=True) +
                    eps)
                grad_miss[:, I[i - 1]:I[i]] /= (
                    torch.std(grad_miss[:, I[i - 1]:I[i]], 1, keepdim=True) +
                    eps)
                grad_avail[:, I[i - 1]:I[i]] /= (
                    torch.std(grad_avail[:, I[i - 1]:I[i]], 1, keepdim=True) +
                    eps)
        #
        if args.augment_size >= M:
            kidx = range(M)
        else:
            print('K-center clustering for K/M:', args.augment_size, M)
            df = fvec_miss
            dg = grad_miss
            # aliases (no actual copy is made)
            dfA = df
            dgA = dg
            # k-centers:
            kidx = list()
            kidx.append(random.sample(range(M), 1)[0])  # initial center
            for b in range(1, args.augment_size):
                K = len(kidx)
                dfB = torch.index_select(
                    df, 0,
                    torch.tensor(kidx,
                                 dtype=torch.long).cuda())  # df[kidx] # KxL
                fDis = torch.mm(dfB, dfA.t())  # KxL * LxM = KxM
                dis = fDis
                if ('Grad' in args.augment_method):  # and (args.batch > 0):
                    dgB = torch.index_select(
                        dg, 0,
                        torch.tensor(
                            kidx, dtype=torch.long).cuda())  # dg[kidx] # KxL
                    gDis = torch.mm(dgB, dgA.t())  # KxL * LxM = KxM
                    dis += gDis
                #
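                # next center = the point whose best similarity to the current centers is smallest (farthest-point rule)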
                fCand = torch.max(dis, dim=0)[0]
                iCand = torch.argmin(fCand, dim=0)
                kidx.append(iCand.item())
        #
        fvec_k_miss = fvec_miss[kidx].clone()
        grad_k_miss = grad_miss[kidx].clone()
        #
        fA = fvec_avail
        gA = grad_avail
        #
        KM = len(kidx)
        C = KM // S  # number of chunks
        #
        topkindex_sim = torch.zeros((K, KM), dtype=torch.long)
        topkvalue_sim = torch.zeros((K, KM))
        print('fA/fB', fA.size(), fvec_k_miss.size())
        #
        for c in range(C + 1):
            idx = range(c * S, min((c + 1) * S, KM))
            fB = fvec_k_miss[idx]
            fSim = torch.mm(fA, fB.t())  # PxL * LxS = PxS
            sim = fSim
            if ('Grad' in args.augment_method):  # and (args.batch > 0):
                gB = grad_k_miss[idx]
                gSim = torch.mm(gA, gB.t())  # PxL * LxS = PxS
                sim += gSim
            #
            simCpu = sim.cpu()
            sim_val, sim_idx = torch.topk(simCpu, K, dim=0,
                                          largest=True)  # KxS
            topkindex_sim[:, idx] = sim_idx
            topkvalue_sim[:, idx] = sim_val
        #
        del descr_miss, descr_avail
        # topkindex_sim KxKM
        sortVal, sortIdx = torch.sort(topkvalue_sim, dim=1,
                                      descending=True)  # KxKM
        sortedindex_sim = torch.zeros((K, KM), dtype=torch.long)
        for k in range(K):
            sortedindex_sim[k] = topkindex_sim[k, sortIdx[k]]
        sortedindex_sim = sortedindex_sim.view(-1)
        topindex_sim = sortedindex_sim[0:args.augment_size]
        i = 1
        while torch.unique(topindex_sim).size(0) != args.augment_size:
            topindex_sim = torch.cat(
                [topindex_sim, sortedindex_sim[args.augment_size + i].view(1)])
            i = i + 1
        print('Search Iteration =', i, topindex_sim.size(0),
              torch.unique(topindex_sim).size(0))
        topindex_sim = torch.unique(topindex_sim)
        topindex_sim = topindex_sim.tolist()
    else:
        print('Wrong augmentation method or some descriptor files are missing')
        sys.exit()

    index_sim = [index_avail[i] for i in topindex_sim]
    set_sim = set(index_sim)
    augment_index_list = list(set_use | set_sim)
    assert len(
        augment_index_list
    ) == args.train_size + args.augment_size, ' Augmented train list length is wrong: {} vs. {} + {}'.format(
        len(augment_index_list), args.train_size, args.augment_size)
    # update train list
    augment_postfix = '{}batch_{}_size_{}_ir_{}_nr_{}_sub_{}_aug_{}_L_{}'.format(
        unsup_prefix, args.batch + 1, args.train_size + args.augment_size,
        args.imbalance_ratio, args.noisy_ratio, args.subtype_method,
        args.augment_method, args.descriptor_length)
    #
    if args.ensemble_size > 1:
        augment_checkpoint_file = '{}/best_{}_E_{}.pt'.format(
            model_folder, augment_postfix, args.ensemble_index)
        augment_train_list_file = '{}/{}/train_list_{}_E_{}.txt'.format(
            args.save_folder, args.run_folder, augment_postfix,
            args.ensemble_size)
        augment_index_list_file = '{}/{}/index_list_{}_E_{}.npy'.format(
            args.save_folder, args.run_folder, augment_postfix,
            args.ensemble_size)
    else:
        augment_checkpoint_file = '{}/best_{}.pt'.format(
            model_folder, augment_postfix)
        augment_train_list_file = '{}/{}/train_list_{}.txt'.format(
            args.save_folder, args.run_folder, augment_postfix)
        augment_index_list_file = '{}/{}/index_list_{}.npy'.format(
            args.save_folder, args.run_folder, augment_postfix)
    #
    np.save(augment_index_list_file, augment_index_list)
    augment_train_list = [full_train_list[i] for i in augment_index_list]
    with open(augment_train_list_file, "w") as f:
        f.write("\n".join(augment_train_list))

    # pipe for train dataset
    pipe = HybridTrainPipe(batch_size=args.batch_size,
                           num_threads=args.workers,
                           device_id=args.local_rank,
                           data_dir=train_dir,
                           file_list=augment_train_list_file,
                           crop=crop_size,
                           local_rank=args.local_rank,
                           world_size=args.world_size,
                           dali_cpu=args.dali_cpu)
    pipe.build()
    train_loader = DALIClassificationIterator(
        pipe, size=int(pipe.epoch_size("Reader") / args.world_size))

    [prec1, prec5] = validate(args, val_loader, model, criterion)
    save_checkpoint(
        {
            'epoch': 0,
            'arch': args.arch,
            'state_dict': model.state_dict(),
            'acc': prec1,
        }, augment_checkpoint_file)
    best_prec1 = prec1
    val_loader.reset()
    for epoch in range(args.start_epoch, args.epochs):
        # train for one epoch
        train(args, train_loader, model, criterion, optimizer, epoch)
        # evaluate on validation set
        [prec1, prec5] = validate(args, val_loader, model, criterion)

        # remember best prec@1 and save checkpoint
        if args.local_rank == 0:
            if prec1 > best_prec1:
                best_prec1 = prec1
                print('Saving best checkpoint at epoch {} with accuracy {}'.
                      format(epoch + 1, best_prec1))
                save_checkpoint(
                    {
                        'epoch': epoch + 1,
                        'arch': args.arch,
                        'state_dict': model.state_dict(),
                        'acc': best_prec1,
                    }, augment_checkpoint_file)
        else:
            print('Local rank is not zero')

        # reset DALI iterators
        train_loader.reset()
        val_loader.reset()

        print('##Top-1 {}, Top-5 {}'.format(prec1, prec5))
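
The greedy k-center selection inside the 'topK' branch above is easier to see in isolation. Below is a minimal sketch of the same idea under simplified assumptions (dot-product similarity, illustrative names not taken from the listing): each new center is the point whose best similarity to the already chosen centers is smallest.

import torch

def greedy_k_center(features, k, seed=0):
    # features: NxL descriptor matrix; similarity between points is the dot product of rows
    g = torch.Generator().manual_seed(seed)
    n = features.size(0)
    centers = [int(torch.randint(n, (1,), generator=g))]  # random initial center
    for _ in range(1, k):
        sim = features[centers] @ features.t()   # KxN similarities to the chosen centers
        closeness = sim.max(dim=0).values        # best similarity of each point to the set
        centers.append(int(closeness.argmin()))  # pick the least covered (farthest) point
    return centers

x = torch.nn.functional.normalize(torch.randn(20, 8), dim=1)
print(greedy_k_center(x, 5))
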
Example #11
0
def main():
    global best_prec1, args

    if args.local_rank == 0 and not os.path.isdir(args.save_dir):
        mkdir_p(args.save_dir)

    args.distributed = False
    if 'WORLD_SIZE' in os.environ:
        args.distributed = int(os.environ['WORLD_SIZE']) > 1

    args.gpu = 0
    args.world_size = 1

    if args.distributed:
        args.gpu = args.local_rank
        torch.cuda.set_device(args.gpu)
        torch.distributed.init_process_group(backend='nccl',
                                             init_method='env://')
        args.world_size = torch.distributed.get_world_size()

    if args.fp16:
        assert torch.backends.cudnn.enabled, "fp16 requires cudnn backend to be enabled."
    if args.static_loss_scale != 1.0:
        if not args.fp16:
            print(
                "Warning:  if --fp16 is not used, static_loss_scale will be ignored."
            )

    # create model
    if args.pretrained:
        if args.local_rank == 0:
            print("=> using pre-trained model '{}'".format(args.arch))
        if args.arch.startswith('resnet'):
            model = resnets.__dict__[args.arch](pretrained=True)
        elif args.arch.startswith('mobilenet'):
            model = mobilenets.__dict__[args.arch](pretrained=True)
        else:
            raise NotImplementedError("Unknown network arch.")
    else:
        if args.local_rank == 0:
            print("=> creating {}".format(args.arch))
        # update args

    if args.pretrained:
        print("=> using pre-trained model '{}'".format(args.arch))
        model = models.__dict__[args.arch](pretrained=True)
    elif 'resnext' in args.arch:
        model = models.__dict__[args.arch](
            baseWidth=args.base_width,
            cardinality=args.cardinality,
        )
    else:
        print("=> creating model '{}'".format(args.arch))
        model = models.__dict__[args.arch]()

    if args.local_rank == 0:
        if args.dataset.startswith('cifar'):
            H, W = 32, 32
        elif args.dataset.startswith('imagenet'):
            H, W = 224, 224
        else:
            raise NotImplementedError("Unknown dataset")
        flops, params = get_model_complexity_info(model, (H, W),
                                                  as_strings=False,
                                                  print_per_layer_stat=False)
        print('=> FLOPs: {:.6f}G, Params: {:.6f}M'.format(
            flops / 1e9, params / 1e6))
        print('=> Params (double-check): %.6fM' %
              (sum(p.numel() for p in model.parameters()) / 1e6))

    if args.sync_bn:
        import apex
        if args.local_rank == 0:
            print("using apex synced BN")
        model = apex.parallel.convert_syncbn_model(model)

    model = model.cuda()
    if args.fp16:
        model = FP16Model(model)
    if args.distributed:
        # By default, apex.parallel.DistributedDataParallel overlaps communication with
        # computation in the backward pass.
        # model = DDP(model)
        # delay_allreduce delays all communication to the end of the backward pass.
        model = DDP(model, delay_allreduce=True)

    if args.pretrained:
        model.load_state_dict(checkpoint['state_dict'])

    # Scale learning rate based on global batch size
    args.lr = args.lr * float(args.batch_size * args.world_size) / 256

    if args.remove_norm_weight_decay:
        if args.local_rank == 0:
            print("=> ! Weight decay NOT applied to FeatNorm parameters ")
        norm_params = set()  #TODO: need to check this via experiments
        rest_params = set()
        for m in model.modules():
            if isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)):
                for param in m.parameters(False):
                    norm_params.add(param)
            else:
                for param in m.parameters(False):
                    rest_params.add(param)

        optimizer = torch.optim.SGD([{
            'params': list(norm_params),
            'weight_decay': 0.0
        }, {
            'params': list(rest_params)
        }],
                                    args.lr,
                                    momentum=args.momentum,
                                    weight_decay=args.weight_decay,
                                    nesterov=args.nesterov)
    else:
        if args.local_rank == 0:
            print("=> ! Weight decay applied to FeatNorm parameters ")
        optimizer = torch.optim.SGD(model.parameters(),
                                    args.lr,
                                    momentum=args.momentum,
                                    weight_decay=args.weight_decay,
                                    nesterov=args.nesterov)

    if args.fp16:
        optimizer = FP16_Optimizer(optimizer,
                                   static_loss_scale=args.static_loss_scale,
                                   dynamic_loss_scale=args.dynamic_loss_scale)

    # define loss function (criterion) and optimizer
    criterion_train = nn.CrossEntropyLoss().cuda() if args.labelsmoothing_rate == 0.0 \
                        else LabelSmoothing(args.labelsmoothing_rate).cuda()
    criterion_val = nn.CrossEntropyLoss().cuda()

    # Optionally resume from a checkpoint
    if args.resume:
        # Use a local scope to avoid dangling references
        def resume():
            if os.path.isfile(args.resume):
                if args.local_rank == 0:
                    print("=> loading checkpoint '{}'".format(args.resume))
                checkpoint = torch.load(
                    args.resume,
                    map_location=lambda storage, loc: storage.cuda(args.gpu))
                args.start_epoch = checkpoint['epoch']
                best_prec1 = checkpoint['best_prec1']
                model.load_state_dict(checkpoint['state_dict'])
                optimizer.load_state_dict(checkpoint['optimizer'])
                if args.local_rank == 0:
                    print("=> loaded checkpoint '{}' (epoch {})".format(
                        args.resume, checkpoint['epoch']))
            else:
                if args.local_rank == 0:
                    print("=> no checkpoint found at '{}'".format(args.resume))

        resume()

    # Data loading code
    if args.dataset == "cifar10":
        train_transform = transforms.Compose([
            transforms.RandomCrop(32, padding=4),
            transforms.RandomHorizontalFlip()
        ])
        train_dataset = datasets.CIFAR10('./datasets',
                                         train=True,
                                         download=False,
                                         transform=train_transform)
        val_dataset = datasets.CIFAR10('./datasets',
                                       train=False,
                                       download=False)
    elif args.dataset == "cifar100":
        train_transform = transforms.Compose([
            transforms.RandomCrop(32, padding=4),
            transforms.RandomHorizontalFlip()
        ])
        train_dataset = datasets.CIFAR100('./datasets',
                                          train=True,
                                          download=False,
                                          transform=train_transform)
        val_dataset = datasets.CIFAR100('./datasets',
                                        train=False,
                                        download=False)
    elif args.dataset == "imagenet":
        traindir = os.path.join(args.data, 'train')
        valdir = os.path.join(args.data, 'val')

        crop_size = args.crop_size  # 224
        val_size = args.crop_size + 32  # 256

        train_dataset = datasets.ImageFolder(
            traindir,
            transforms.Compose([
                transforms.RandomResizedCrop(
                    crop_size, interpolation=args.crop_interpolation),
                transforms.RandomHorizontalFlip(),
                # transforms.ToTensor(), Too slow
                # normalize,
            ]))
        val_dataset = datasets.ImageFolder(
            valdir,
            transforms.Compose([
                transforms.Resize(val_size,
                                  interpolation=args.crop_interpolation),
                transforms.CenterCrop(crop_size),
            ]))

    train_sampler = None
    val_sampler = None
    if args.distributed:
        train_sampler = torch.utils.data.distributed.DistributedSampler(
            train_dataset)
        val_sampler = torch.utils.data.distributed.DistributedSampler(
            val_dataset)

    train_loader = torch.utils.data.DataLoader(train_dataset,
                                               batch_size=args.batch_size,
                                               shuffle=(train_sampler is None),
                                               num_workers=args.workers,
                                               pin_memory=True,
                                               sampler=train_sampler,
                                               collate_fn=fast_collate)

    val_loader = torch.utils.data.DataLoader(val_dataset,
                                             batch_size=args.batch_size,
                                             shuffle=False,
                                             num_workers=args.workers,
                                             pin_memory=True,
                                             sampler=val_sampler,
                                             collate_fn=fast_collate)
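    # NOTE: fast_collate is assumed to batch PIL images into uint8 tensors, deferring normalization (e.g., to a GPU-side prefetcher)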

    if args.evaluate:
        validate(val_loader, model, criterion_val)
        return

    scheduler = CosineAnnealingLR(
        optimizer.optimizer if args.fp16 else optimizer,
        args.epochs,
        len(train_loader),
        eta_min=0.,
        warmup=args.warmup_epochs)

    for epoch in range(args.start_epoch, args.epochs):
        if args.distributed:
            train_sampler.set_epoch(epoch)

        # train for one epoch
        train(train_loader, model, criterion_train, optimizer, epoch,
              scheduler, args.warmup_epochs, args.mixup_rate,
              args.labelsmoothing_rate)
        #TODO: warmup_epochs, labelsmoothing_rate, mixup_rate, args.dataset, args.cropsize, args.crop_interpolation
        if args.prof:
            break
        # evaluate on validation set
        prec1 = validate(val_loader, model, criterion_val)

        # remember best prec@1 and save checkpoint
        if args.local_rank == 0:
            is_best = prec1 > best_prec1
            best_prec1 = max(prec1, best_prec1)
            save_checkpoint(
                {
                    'epoch': epoch + 1,
                    'arch': args.arch,
                    'state_dict': model.state_dict(),
                    'best_prec1': best_prec1,
                    'optimizer': optimizer.state_dict(),
                }, is_best, args.save_dir)
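
The example above applies the linear scaling rule (the learning rate grows with the global batch size, reference batch 256) and then anneals with warmup plus cosine decay. The helper below is an illustrative sketch of that combined schedule, not the CosineAnnealingLR class used in the listing.

import math

def lr_at(step, base_lr, batch_size, world_size, warmup_steps, total_steps, eta_min=0.0):
    # linear scaling rule: scale the base lr by the global batch size relative to 256
    lr = base_lr * float(batch_size * world_size) / 256.0
    if step < warmup_steps:
        # linear warmup towards the scaled lr
        return lr * (step + 1) / warmup_steps
    # cosine decay from the scaled lr down to eta_min over the remaining steps
    progress = (step - warmup_steps) / max(1, total_steps - warmup_steps)
    return eta_min + 0.5 * (lr - eta_min) * (1.0 + math.cos(math.pi * progress))

print([round(lr_at(s, 0.1, 256, 4, 5, 100), 4) for s in (0, 4, 5, 50, 99)])
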
Example #12
0
def main(pargs):

    # this should be global
    global have_wandb

    #init distributed training
    comm.init(pargs.wireup_method)
    comm_rank = comm.get_rank()
    comm_local_rank = comm.get_local_rank()
    comm_size = comm.get_size()

    # set up logging
    pargs.logging_frequency = max([pargs.logging_frequency, 1])
    log_file = os.path.normpath(
        os.path.join(pargs.output_dir, "logs", pargs.run_tag + ".log"))
    logger = mll.mlperf_logger(log_file, "deepcam", "Umbrella Corp.")
    logger.log_start(key="init_start", sync=True)
    logger.log_event(key="cache_clear")

    #set seed
    seed = 333
    logger.log_event(key="seed", value=seed)

    # Some setup
    torch.manual_seed(seed)
    if torch.cuda.is_available():
        device = torch.device("cuda", comm_local_rank)
        torch.cuda.manual_seed(seed)
        #necessary for AMP to work
        torch.cuda.set_device(device)
    else:
        device = torch.device("cpu")

    #visualize?
    visualize = (pargs.training_visualization_frequency >
                 0) or (pargs.validation_visualization_frequency > 0)

    #set up directories
    root_dir = os.path.join(pargs.data_dir_prefix)
    output_dir = pargs.output_dir
    plot_dir = os.path.join(output_dir, "plots")
    if comm_rank == 0:
        if not os.path.isdir(output_dir):
            os.makedirs(output_dir)
        if visualize and not os.path.isdir(plot_dir):
            os.makedirs(plot_dir)

    # Setup WandB
    if not pargs.enable_wandb:
        have_wandb = False
    if have_wandb and (comm_rank == 0):
        # get wandb api token
        certfile = os.path.join(pargs.wandb_certdir, ".wandbirc")
        try:
            with open(certfile) as f:
                token = f.readlines()[0].replace("\n", "").split()
                wblogin = token[0]
                wbtoken = token[1]
        except IOError:
            print("Error, cannot open WandB certificate {}.".format(certfile))
            have_wandb = False

        if have_wandb:
            # log in: that call can be blocking, it should be quick
            sp.call(["wandb", "login", wbtoken])

            #init db and get config
            resume_flag = pargs.run_tag if pargs.resume_logging else False
            wandb.init(entity=wblogin,
                       project='deepcam',
                       name=pargs.run_tag,
                       id=pargs.run_tag,
                       resume=resume_flag)
            config = wandb.config

            #set general parameters
            config.root_dir = root_dir
            config.output_dir = pargs.output_dir
            config.max_epochs = pargs.max_epochs
            config.local_batch_size = pargs.local_batch_size
            config.num_workers = comm_size
            config.channels = pargs.channels
            config.optimizer = pargs.optimizer
            config.start_lr = pargs.start_lr
            config.adam_eps = pargs.adam_eps
            config.weight_decay = pargs.weight_decay
            config.model_prefix = pargs.model_prefix
            config.amp_opt_level = pargs.amp_opt_level
            config.loss_weight_pow = pargs.loss_weight_pow
            config.lr_warmup_steps = pargs.lr_warmup_steps
            config.lr_warmup_factor = pargs.lr_warmup_factor

            # lr schedule if applicable
            if pargs.lr_schedule:
                for key in pargs.lr_schedule:
                    config.update(
                        {"lr_schedule_" + key: pargs.lr_schedule[key]},
                        allow_val_change=True)

    # initial logging
    logger.log_event(key="global_batch_size",
                     value=(pargs.local_batch_size * comm_size))
    logger.log_event(key="optimizer", value=pargs.optimizer)

    # Define architecture
    n_input_channels = len(pargs.channels)
    n_output_channels = 3
    net = deeplab_xception.DeepLabv3_plus(n_input=n_input_channels,
                                          n_classes=n_output_channels,
                                          os=16,
                                          pretrained=False,
                                          rank=comm_rank)
    net.to(device)

    #select loss
    loss_pow = pargs.loss_weight_pow
    # hard-coded per-class statistics (summing to ~1) raised to loss_weight_pow
    class_weights = [
        0.986267818390377**loss_pow, 0.0004578708870701058**loss_pow,
        0.01327431072255291**loss_pow
    ]
    fpw_1 = 2.61461122397522257612
    fpw_2 = 1.71641974795896018744
    criterion = losses.fp_loss

    #select optimizer
    optimizer = None
    if pargs.optimizer == "Adam":
        optimizer = optim.Adam(net.parameters(),
                               lr=pargs.start_lr,
                               eps=pargs.adam_eps,
                               weight_decay=pargs.weight_decay)
    elif pargs.optimizer == "AdamW":
        optimizer = optim.AdamW(net.parameters(),
                                lr=pargs.start_lr,
                                eps=pargs.adam_eps,
                                weight_decay=pargs.weight_decay)
    elif have_apex and (pargs.optimizer == "LAMB"):
        optimizer = aoptim.FusedLAMB(net.parameters(),
                                     lr=pargs.start_lr,
                                     eps=pargs.adam_eps,
                                     weight_decay=pargs.weight_decay)
    else:
        raise NotImplementedError("Error, optimizer {} not supported".format(
            pargs.optimizer))

    if have_apex:
        #wrap model and opt into amp
        net, optimizer = amp.initialize(net,
                                        optimizer,
                                        opt_level=pargs.amp_opt_level)

    #make model distributed
    net = DDP(net)

    #restart from checkpoint if desired
    #if (comm_rank == 0) and (pargs.checkpoint):
    #load it on all ranks for now
    if pargs.checkpoint:
        checkpoint = torch.load(pargs.checkpoint, map_location=device)
        start_step = checkpoint['step']
        start_epoch = checkpoint['epoch']
        optimizer.load_state_dict(checkpoint['optimizer'])
        net.load_state_dict(checkpoint['model'])
        if have_apex:
            amp.load_state_dict(checkpoint['amp'])
    else:
        start_step = 0
        start_epoch = 0

    #select scheduler
    if pargs.lr_schedule:
        scheduler_after = ph.get_lr_schedule(pargs.start_lr,
                                             pargs.lr_schedule,
                                             optimizer,
                                             last_step=start_step)

        if have_warmup_scheduler and (pargs.lr_warmup_steps > 0):
            scheduler = GradualWarmupScheduler(
                optimizer,
                multiplier=pargs.lr_warmup_factor,
                total_epoch=pargs.lr_warmup_steps,
                after_scheduler=scheduler_after)
        else:
            scheduler = scheduler_after

    #broadcast start step and epoch from rank 0
    steptens = torch.tensor(np.array([start_step, start_epoch]),
                            requires_grad=False).to(device)
    dist.broadcast(steptens, src=0)

    ##broadcast model and optimizer state
    #hvd.broadcast_parameters(net.state_dict(), root_rank = 0)
    #hvd.broadcast_optimizer_state(optimizer, root_rank = 0)

    #unpack the bcasted tensor
    start_step = steptens.cpu().numpy()[0]
    start_epoch = steptens.cpu().numpy()[1]

    # Set up the data feeder
    # train
    train_dir = os.path.join(root_dir, "train")
    """train_set = cam.CamDataset(train_dir,
                               statsfile = os.path.join(root_dir, 'stats.h5'),
                               channels = pargs.channels,
                               allow_uneven_distribution = False,
                               shuffle = True, 
                               preprocess = True,
                               comm_size = comm_size,
                               comm_rank = comm_rank)"""
    train_set = HDMLPCam.HDMLPCam(
        train_dir, os.path.join(root_dir, 'stats.h5'), pargs.channels, 2,
        pargs.max_epochs, False,
        "/Users/roman/PyCharmProjects/hdmlp/libhdmlp/data/hdmlp.cfg")
    train_loader = hdmlp.lib.torch.HDMLPDataLoader(train_set)
    """train_loader = DataLoader(train_set,
                              pargs.local_batch_size,
                              num_workers = min([pargs.max_inter_threads, pargs.local_batch_size]),
                              pin_memory = True,
                              drop_last = True)"""

    # validation: we only want to shuffle the set if we are cutting off validation after a certain number of steps
    validation_dir = os.path.join(root_dir, "validation")
    validation_set = cam.CamDataset(validation_dir,
                                    statsfile=os.path.join(
                                        root_dir, 'stats.h5'),
                                    channels=pargs.channels,
                                    allow_uneven_distribution=True,
                                    shuffle=(pargs.max_validation_steps
                                             is not None),
                                    preprocess=True,
                                    comm_size=comm_size,
                                    comm_rank=comm_rank)
    # use batch size = 1 here to make sure that we do not drop a sample
    validation_loader = DataLoader(
        validation_set,
        1,
        num_workers=min([pargs.max_inter_threads, pargs.local_batch_size]),
        pin_memory=True,
        drop_last=True)

    # log size of datasets
    logger.log_event(key="train_samples", value=train_set.global_size)
    if pargs.max_validation_steps is not None:
        val_size = min([
            validation_set.global_size,
            pargs.max_validation_steps * pargs.local_batch_size * comm_size
        ])
    else:
        val_size = validation_set.global_size
    logger.log_event(key="eval_samples", value=val_size)

    # do sanity check
    if pargs.max_validation_steps is not None:
        logger.log_event(key="invalid_submission")

    #for visualization
    if visualize:
        viz = vizc.CamVisualizer()

    # Train network
    if have_wandb and (comm_rank == 0):
        wandb.watch(net)

    step = start_step
    epoch = start_epoch
    current_lr = pargs.start_lr if not pargs.lr_schedule else scheduler.get_last_lr(
    )[0]
    net.train()

    # start training
    logger.log_end(key="init_stop", sync=True)
    logger.log_start(key="run_start", sync=True)

    # training loop
    while True:

        # start epoch
        logger.log_start(key="epoch_start",
                         metadata={
                             'epoch_num': epoch + 1,
                             'step_num': step
                         },
                         sync=True)

        # epoch loop
        for inputs, label in train_loader:
            filename = "n/a"

            # send to device
            inputs = inputs.to(device)
            label = label.to(device)

            # forward pass
            outputs = net.forward(inputs)

            # Compute loss and average across nodes
            loss = criterion(outputs,
                             label,
                             weight=class_weights,
                             fpw_1=fpw_1,
                             fpw_2=fpw_2)

            # Backprop
            optimizer.zero_grad()
            if have_apex:
                with amp.scale_loss(loss, optimizer) as scaled_loss:
                    scaled_loss.backward()
            else:
                loss.backward()
            optimizer.step()

            # step counter
            step += 1

            if pargs.lr_schedule:
                current_lr = scheduler.get_last_lr()[0]
                scheduler.step()

            #visualize if requested
            if (step % pargs.training_visualization_frequency
                    == 0) and (comm_rank == 0):
                # Compute predictions
                predictions = torch.max(outputs, 1)[1]

                # extract sample id and data tensors
                sample_idx = np.random.randint(low=0, high=label.shape[0])
                plot_input = inputs.detach()[sample_idx, 0, ...].cpu().numpy()
                plot_prediction = predictions.detach()[sample_idx,
                                                       ...].cpu().numpy()
                plot_label = label.detach()[sample_idx, ...].cpu().numpy()

                # create filenames
                outputfile = os.path.basename(filename[sample_idx]).replace(
                    "data-", "training-").replace(".h5", ".png")
                outputfile = os.path.join(plot_dir, outputfile)

                # plot
                viz.plot(filename[sample_idx], outputfile, plot_input,
                         plot_prediction, plot_label)

                #log if requested
                if have_wandb:
                    img = Image.open(outputfile)
                    wandb.log(
                        {
                            "train_examples": [
                                wandb.Image(
                                    img, caption="Prediction vs. Ground Truth")
                            ]
                        },
                        step=step)

            #log if requested
            if (step % pargs.logging_frequency == 0):

                # reduce the loss to rank 0 and average over ranks
                loss_avg = loss.detach()
                dist.reduce(loss_avg, dst=0, op=dist.ReduceOp.SUM)
                loss_avg_train = loss_avg.item() / float(comm_size)

                # Compute score
                predictions = torch.max(outputs, 1)[1]
                iou = utils.compute_score(predictions,
                                          label,
                                          device_id=device,
                                          num_classes=3)
                iou_avg = iou.detach()
                dist.reduce(iou_avg, dst=0, op=dist.ReduceOp.SUM)
                iou_avg_train = iou_avg.item() / float(comm_size)

                logger.log_event(key="learning_rate",
                                 value=current_lr,
                                 metadata={
                                     'epoch_num': epoch + 1,
                                     'step_num': step
                                 })
                logger.log_event(key="train_accuracy",
                                 value=iou_avg_train,
                                 metadata={
                                     'epoch_num': epoch + 1,
                                     'step_num': step
                                 })
                logger.log_event(key="train_loss",
                                 value=loss_avg_train,
                                 metadata={
                                     'epoch_num': epoch + 1,
                                     'step_num': step
                                 })

                if have_wandb and (comm_rank == 0):
                    wandb.log(
                        {"train_loss": loss_avg.item() / float(comm_size)},
                        step=step)
                    wandb.log(
                        {"train_accuracy": iou_avg.item() / float(comm_size)},
                        step=step)
                    wandb.log({"learning_rate": current_lr}, step=step)

            # validation step if desired
            if (step % pargs.validation_frequency == 0):

                logger.log_start(key="eval_start",
                                 metadata={'epoch_num': epoch + 1})

                #eval
                net.eval()

                count_sum_val = torch.Tensor([0.]).to(device)
                loss_sum_val = torch.Tensor([0.]).to(device)
                iou_sum_val = torch.Tensor([0.]).to(device)

                # disable gradients
                with torch.no_grad():

                    # iterate over validation sample
                    step_val = 0
                    # only print once per eval at most
                    visualized = False
                    for inputs_val, label_val, filename_val in validation_loader:

                        #send to device
                        inputs_val = inputs_val.to(device)
                        label_val = label_val.to(device)

                        # forward pass
                        outputs_val = net.forward(inputs_val)

                        # Compute loss and average across nodes
                        loss_val = criterion(outputs_val,
                                             label_val,
                                             weight=class_weights,
                                             fpw_1=fpw_1,
                                             fpw_2=fpw_2)
                        loss_sum_val += loss_val

                        #increase counter
                        count_sum_val += 1.

                        # Compute score
                        predictions_val = torch.max(outputs_val, 1)[1]
                        iou_val = utils.compute_score(predictions_val,
                                                      label_val,
                                                      device_id=device,
                                                      num_classes=3)
                        iou_sum_val += iou_val

                        # Visualize
                        if (step_val % pargs.validation_visualization_frequency
                                == 0) and (not visualized) and (comm_rank
                                                                == 0):
                            #extract sample id and data tensors
                            sample_idx = np.random.randint(
                                low=0, high=label_val.shape[0])
                            plot_input = inputs_val.detach()[
                                sample_idx, 0, ...].cpu().numpy()
                            plot_prediction = predictions_val.detach()[
                                sample_idx, ...].cpu().numpy()
                            plot_label = label_val.detach()[sample_idx,
                                                            ...].cpu().numpy()

                            #create filenames
                            outputfile = os.path.basename(
                                filename_val[sample_idx]).replace(
                                    "data-",
                                    "validation-").replace(".h5", ".png")
                            outputfile = os.path.join(plot_dir, outputfile)

                            #plot
                            viz.plot(filename_val[sample_idx], outputfile,
                                     plot_input, plot_prediction, plot_label)
                            visualized = True

                            #log if requested
                            if have_wandb:
                                img = Image.open(outputfile)
                                wandb.log(
                                    {
                                        "eval_examples": [
                                            wandb.Image(
                                                img,
                                                caption=
                                                "Prediction vs. Ground Truth")
                                        ]
                                    },
                                    step=step)

                        #increase eval step counter
                        step_val += 1

                        if (pargs.max_validation_steps is not None
                            ) and step_val > pargs.max_validation_steps:
                            break

                # average the validation loss
                dist.reduce(count_sum_val, dst=0, op=dist.ReduceOp.SUM)
                dist.reduce(loss_sum_val, dst=0, op=dist.ReduceOp.SUM)
                dist.reduce(iou_sum_val, dst=0, op=dist.ReduceOp.SUM)
                loss_avg_val = loss_sum_val.item() / count_sum_val.item()
                iou_avg_val = iou_sum_val.item() / count_sum_val.item()

                # print results
                logger.log_event(key="eval_accuracy",
                                 value=iou_avg_val,
                                 metadata={
                                     'epoch_num': epoch + 1,
                                     'step_num': step
                                 })
                logger.log_event(key="eval_loss",
                                 value=loss_avg_val,
                                 metadata={
                                     'epoch_num': epoch + 1,
                                     'step_num': step
                                 })

                # log in wandb
                if have_wandb and (comm_rank == 0):
                    wandb.log({"eval_loss": loss_avg_val}, step=step)
                    wandb.log({"eval_accuracy": iou_avg_val}, step=step)

                if (iou_avg_val >= pargs.target_iou):
                    logger.log_event(key="target_accuracy_reached",
                                     value=pargs.target_iou,
                                     metadata={
                                         'epoch_num': epoch + 1,
                                         'step_num': step
                                     })

                # set to train
                net.train()

                logger.log_end(key="eval_stop",
                               metadata={'epoch_num': epoch + 1})

            #save model if desired
            if (pargs.save_frequency > 0) and (step % pargs.save_frequency
                                               == 0):
                logger.log_start(key="save_start",
                                 metadata={
                                     'epoch_num': epoch + 1,
                                     'step_num': step
                                 },
                                 sync=True)
                if comm_rank == 0:
                    checkpoint = {
                        'step': step,
                        'epoch': epoch,
                        'model': net.state_dict(),
                        'optimizer': optimizer.state_dict()
                    }
                    if have_apex:
                        checkpoint['amp'] = amp.state_dict()
                    torch.save(
                        checkpoint,
                        os.path.join(
                            output_dir, pargs.model_prefix + "_step_" +
                            str(step) + ".cpt"))
                logger.log_end(key="save_stop",
                               metadata={
                                   'epoch_num': epoch + 1,
                                   'step_num': step
                               },
                               sync=True)

        # log the epoch
        logger.log_end(key="epoch_stop",
                       metadata={
                           'epoch_num': epoch + 1,
                           'step_num': step
                       },
                       sync=True)
        epoch += 1

        # are we done?
        if epoch >= pargs.max_epochs:
            break

    # run done
    logger.log_end(key="run_stop", sync=True, metadata={'status': 'success'})
Example #13
0
def main():
    def evaluate(dataloader, export=None):
        eval_loss, eval_accuracy = 0, 0
        nb_eval_steps, nb_eval_examples = 0, 0
        logits_list = []
        iter_idx = 0
        for input_ids, input_mask, segment_ids, label_ids in tqdm(
                dataloader, desc="Evaluating"):
            input_ids = input_ids.to(device)
            input_mask = input_mask.to(device)
            segment_ids = segment_ids.to(device)
            label_ids = label_ids.to(device)

            with torch.no_grad():
                tmp_eval_loss = model(input_ids, segment_ids, input_mask,
                                      label_ids)
                logits = model(input_ids, segment_ids, input_mask)

            logits = logits.detach().cpu().numpy()
            if export is not None:
                logits_list.append(logits)
            label_ids = label_ids.to('cpu').numpy()
            tmp_eval_accuracy = accuracy(logits, label_ids)

            eval_loss += tmp_eval_loss.mean().item()
            eval_accuracy += tmp_eval_accuracy

            nb_eval_examples += input_ids.size(0)
            nb_eval_steps += 1
            if (iter_idx + 1) % 1000 == 0 and export is not None:
                torch.save((iter_idx, logits_list), export)
            iter_idx += 1
        if export is not None:
            torch.save(logits_list, export)

        eval_loss = eval_loss / nb_eval_steps
        eval_accuracy = eval_accuracy / nb_eval_examples
        loss = tr_loss / nb_tr_steps if args.do_train else None
        result = {
            'eval_loss': eval_loss,
            'eval_accuracy': eval_accuracy,
            'global_step': global_step,
            'loss': loss
        }
        return result

    local_rank = -1
    parser = argparse.ArgumentParser()
    parser.add_argument("--config", "-c", type=str, required=True)
    args, _ = parser.parse_known_args()
    options = argconf.options_from_json("confs/options.json")
    config = argconf.config_from_json(args.config)
    args = edict(argconf.parse_args(options, config))
    print(f"Using config: {args}")
    bv_utils.set_seed(args.seed)
    args.do_train = args.do_train and not args.do_test_only

    processors = {
        "cola": ColaProcessor,
        "mnli": MnliProcessor,
        "mrpc": MrpcProcessor,
        "sst2": SST2Processor,
        "imdb": IMDBSentenceProcessor,
        "qqp": QuoraProcessor,
        "sts": STSProcessor,
        "raw_pair": RawPairProcessor
    }

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    n_gpu = torch.cuda.device_count()
    logger.info(
        "device: {} n_gpu: {}, distributed training: {}, 16-bits training: {}".
        format(device, n_gpu, bool(local_rank != -1), args.fp16))

    if args.gradient_accumulation_steps < 1:
        raise ValueError(
            "Invalid gradient_accumulation_steps parameter: {}, should be >= 1"
            .format(args.gradient_accumulation_steps))

    args.train_batch_size = args.train_batch_size // args.gradient_accumulation_steps

    if not args.do_train and not args.do_eval:
        raise ValueError(
            "At least one of `do_train` or `do_eval` must be True.")

    if os.path.exists(args.workspace) and os.listdir(
            args.workspace) and args.do_train:
        raise ValueError(
            "Output directory ({}) already exists and is not empty.".format(
                args.workspace))
    if not os.path.exists(args.workspace):
        os.makedirs(args.workspace)

    task_name = args.task_name.lower()

    if task_name not in processors:
        raise ValueError("Task not found: %s" % (task_name))

    processor = processors[task_name]()
    num_labels = args.n_labels
    label_list = processor.get_labels()

    tokenizer = BertTokenizer.from_pretrained(args.model_file,
                                              do_lower_case=args.uncased)

    num_train_optimization_steps = None
    train_examples = processor.get_train_examples(args.data_dir)
    num_train_optimization_steps = int(
        len(train_examples) / args.train_batch_size /
        args.gradient_accumulation_steps) * args.num_train_epochs
    if local_rank != -1:
        num_train_optimization_steps = num_train_optimization_steps // torch.distributed.get_world_size(
        )

    # Prepare model
    cache_dir = os.path.join(PYTORCH_PRETRAINED_BERT_CACHE,
                             'distributed_{}'.format(local_rank))
    model = BertForSequenceClassification.from_pretrained(
        args.model_file, cache_dir=cache_dir, num_labels=num_labels)
    if args.fp16:
        model.half()
    model.to(device)
    if local_rank != -1:
        try:
            from apex.parallel import DistributedDataParallel as DDP
        except ImportError:
            raise ImportError(
                "Please install apex from https://www.github.com/nvidia/apex to use distributed and fp16 training."
            )

        model = DDP(model)
    elif n_gpu > 1:
        model = torch.nn.DataParallel(model)

    # Prepare optimizer
    param_optimizer = list(model.named_parameters())
    no_decay = ['bias', 'LayerNorm.bias', 'LayerNorm.weight']
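    # parameters whose names match no_decay get weight_decay 0.0; all others get 0.01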
    optimizer_grouped_parameters = [{
        'params':
        [p for n, p in param_optimizer if not any(nd in n for nd in no_decay)],
        'weight_decay':
        0.01
    }, {
        'params':
        [p for n, p in param_optimizer if any(nd in n for nd in no_decay)],
        'weight_decay':
        0.0
    }]
    if args.fp16:
        try:
            from apex.optimizers import FP16_Optimizer
            from apex.optimizers import FusedAdam
        except ImportError:
            raise ImportError(
                "Please install apex from https://www.github.com/nvidia/apex to use distributed and fp16 training."
            )

        optimizer = FusedAdam(optimizer_grouped_parameters,
                              lr=args.learning_rate,
                              bias_correction=False,
                              max_grad_norm=1.0)
        if args.loss_scale == 0:
            optimizer = FP16_Optimizer(optimizer, dynamic_loss_scale=True)
        else:
            optimizer = FP16_Optimizer(optimizer,
                                       static_loss_scale=args.loss_scale)

    else:
        optimizer = BertAdam(optimizer_grouped_parameters,
                             lr=args.learning_rate,
                             warmup=args.warmup_proportion,
                             t_total=num_train_optimization_steps)

    global_step = 0
    nb_tr_steps = 0
    tr_loss = 0
    train_features = convert_examples_to_features(train_examples, label_list,
                                                  args.max_seq_length,
                                                  tokenizer)
    all_input_ids = torch.tensor([f.input_ids for f in train_features],
                                 dtype=torch.long)
    all_input_mask = torch.tensor([f.input_mask for f in train_features],
                                  dtype=torch.long)
    all_segment_ids = torch.tensor([f.segment_ids for f in train_features],
                                   dtype=torch.long)
    all_label_ids = torch.tensor([f.label_id for f in train_features],
                                 dtype=torch.long)
    train_data = TensorDataset(all_input_ids, all_input_mask, all_segment_ids,
                               all_label_ids)
    if local_rank == -1:
        train_sampler = RandomSampler(train_data)
    else:
        train_sampler = DistributedSampler(train_data)
    train_dataloader = DataLoader(train_data,
                                  sampler=train_sampler,
                                  batch_size=args.train_batch_size)
    if args.do_train:
        logger.info("***** Running training *****")
        logger.info("  Num examples = %d", len(train_examples))
        logger.info("  Batch size = %d", args.train_batch_size)
        logger.info("  Num steps = %d", num_train_optimization_steps)

        model.train()
        for _ in trange(int(args.num_train_epochs), desc="Epoch"):
            tr_loss = 0
            nb_tr_examples, nb_tr_steps = 0, 0
            for step, batch in enumerate(
                    tqdm(train_dataloader, desc="Iteration")):
                batch = tuple(t.to(device) for t in batch)
                input_ids, input_mask, segment_ids, label_ids = batch
                loss = model(input_ids, segment_ids, input_mask, label_ids)
                if n_gpu > 1:
                    loss = loss.mean()  # mean() to average on multi-gpu.
                if args.gradient_accumulation_steps > 1:
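                    # scale the loss so gradients accumulated over several micro-batches match one full-batch step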
                    loss = loss / args.gradient_accumulation_steps

                if args.fp16:
                    optimizer.backward(loss)
                else:
                    loss.backward()

                tr_loss += loss.item()
                nb_tr_examples += input_ids.size(0)
                nb_tr_steps += 1
                if (step + 1) % args.gradient_accumulation_steps == 0:
                    if args.fp16:
                        # modify the learning rate with the special warmup schedule BERT uses
                        # (a sketch of this schedule follows this example); if args.fp16 is
                        # False, BertAdam is used, which handles this automatically
                        lr_this_step = args.learning_rate * warmup_linear(
                            global_step / num_train_optimization_steps,
                            args.warmup_proportion)
                        for param_group in optimizer.param_groups:
                            param_group['lr'] = lr_this_step
                    optimizer.step()
                    optimizer.zero_grad()
                    global_step += 1

    output_model_file = os.path.join(args.workspace, WEIGHTS_NAME)
    if args.do_train:
        # Save a trained model and the associated configuration
        model_to_save = model.module if hasattr(
            model, 'module') else model  # Only save the model itself
        torch.save(model_to_save.state_dict(), output_model_file)
        output_config_file = os.path.join(args.workspace, CONFIG_NAME)
        with open(output_config_file, 'w') as f:
            f.write(model_to_save.config.to_json_string())

        # Load a trained model and config that you have fine-tuned
        config = BertConfig(output_config_file)
        model = BertForSequenceClassification(config, num_labels=num_labels)
        model.load_state_dict(torch.load(output_model_file))
    elif args.do_test_only:
        convert = bv_utils.convert_single_to_dp if isinstance(
            model, torch.nn.DataParallel) else bv_utils.convert_dp_to_single
        model.load_state_dict(convert(torch.load(output_model_file)))
    else:
        model = BertForSequenceClassification.from_pretrained(
            args.model_file, num_labels=num_labels)
    model.to(device)

    if args.export:
        model.eval()
        train_dataloader = DataLoader(train_data,
                                      batch_size=args.eval_batch_size,
                                      shuffle=False)
        with torch.no_grad():
            evaluate(train_dataloader, export=args.export)

    if args.visualize:
        model.eval()
        train_dataloader = DataLoader(train_data,
                                      batch_size=args.eval_batch_size,
                                      shuffle=False)
        with open(os.path.join(args.workspace, "viz_results.csv"), "w") as f:
            writer = None
            dir_a = bv_viz.choose_random_dir(list(model.parameters()))
            dir_b = bv_viz.choose_random_dir(list(model.parameters()))
            torch.save(dir_a, os.path.join(args.workspace, "viz_dir_a.pt"))
            torch.save(dir_b, os.path.join(args.workspace, "viz_dir_b.pt"))
            for a, b in bv_viz.contour_2d(model, dir_a, dir_b):
                result = evaluate(train_dataloader)
                result["a"] = a
                result["b"] = b
                if writer is None:
                    writer = csv.DictWriter(f, fieldnames=result.keys())
                    writer.writeheader()
                for key in sorted(result.keys()):
                    logger.info("  %s = %s", key, str(result[key]))
                writer.writerow(result)

    if args.do_eval and (local_rank == -1
                         or torch.distributed.get_rank() == 0):
        eval_examples = processor.get_test_examples(
            args.data_dir
        ) if args.do_test_only else processor.get_dev_examples(args.data_dir)
        eval_features = convert_examples_to_features(eval_examples, label_list,
                                                     args.max_seq_length,
                                                     tokenizer)
        logger.info("***** Running evaluation *****")
        logger.info("  Num examples = %d", len(eval_examples))
        logger.info("  Batch size = %d", args.eval_batch_size)
        all_input_ids = torch.tensor([f.input_ids for f in eval_features],
                                     dtype=torch.long)
        all_input_mask = torch.tensor([f.input_mask for f in eval_features],
                                      dtype=torch.long)
        all_segment_ids = torch.tensor([f.segment_ids for f in eval_features],
                                       dtype=torch.long)
        all_label_ids = torch.tensor([f.label_id for f in eval_features],
                                     dtype=torch.long)
        eval_data = TensorDataset(all_input_ids, all_input_mask,
                                  all_segment_ids, all_label_ids)
        # Run prediction for full data
        eval_sampler = SequentialSampler(eval_data)
        eval_dataloader = DataLoader(eval_data,
                                     sampler=eval_sampler,
                                     batch_size=args.eval_batch_size)

        model.eval()
        result = evaluate(eval_dataloader)

        output_eval_file = os.path.join(args.workspace, "eval_results.txt")
        with open(output_eval_file, "w") as writer:
            logger.info("***** Eval results *****")
            for key in sorted(result.keys()):
                logger.info("  %s = %s", key, str(result[key]))
                writer.write("%s = %s\n" % (key, str(result[key])))
def main():
    parser = argparse.ArgumentParser()

    ## Required parameters
    parser.add_argument(
        "--bert_model",
        default="bert-base-uncased",
        type=str,
        help="Bert pre-trained model selected in the list: bert-base-uncased, "
        "bert-large-uncased, bert-base-cased, bert-large-cased, bert-base-multilingual-uncased, "
        "bert-base-multilingual-cased, bert-base-chinese.")
    parser.add_argument(
        "--output_dir",
        default=None,
        type=str,
        required=True,
        help=
        "The output directory where the model predictions and checkpoints will be written."
    )
    parser.add_argument(
        "--cache_dir",
        default="",
        type=str,
        help=
        "Where do you want to store the pre-trained models downloaded from s3")
    parser.add_argument(
        "--max_seq_length",
        default=128,
        type=int,
        help=
        "The maximum total input sequence length after WordPiece tokenization. \n"
        "Sequences longer than this will be truncated, and sequences shorter \n"
        "than this will be padded.")
    parser.add_argument("--do_train",
                        action='store_true',
                        help="Whether to run training.")
    parser.add_argument("--do_eval",
                        action='store_true',
                        help="Whether to run eval on the dev set.")
    parser.add_argument("--do_eval_on_train",
                        action='store_true',
                        help="Whether to run eval on the train set.")
    parser.add_argument("--do_test",
                        action='store_true',
                        help="Whether to run test and create submission.")
    parser.add_argument("--train_batch_size",
                        default=32,
                        type=int,
                        help="Total batch size for training.")
    parser.add_argument("--seed",
                        default=None,
                        type=int,
                        help="Seed for randomized elements in the training")
    parser.add_argument("--eval_batch_size",
                        default=16,
                        type=int,
                        help="Total batch size for eval.")
    parser.add_argument("--learning_rate",
                        default=5e-5,
                        type=float,
                        help="The initial learning rate for Adam.")
    parser.add_argument("--num_train_epochs",
                        default=3.0,
                        type=float,
                        help="Total number of training epochs to perform.")
    parser.add_argument(
        "--warmup_proportion",
        default=0.1,
        type=float,
        help=
        "Proportion of training to perform linear learning rate warmup for. "
        "E.g., 0.1 = 10%% of training.")
    parser.add_argument("--no_cuda",
                        action='store_true',
                        help="Whether not to use CUDA when available")
    parser.add_argument("--local_rank",
                        type=int,
                        default=-1,
                        help="local_rank for distributed training on gpus")
    parser.add_argument(
        '--gradient_accumulation_steps',
        type=int,
        default=1,
        help=
        "Number of updates steps to accumulate before performing a backward/update pass."
    )
    parser.add_argument(
        '--fp16',
        action='store_true',
        help="Whether to use 16-bit float precision instead of 32-bit")
    parser.add_argument(
        '--loss_scale',
        type=float,
        default=0,
        help=
        "Loss scaling to improve fp16 numeric stability. Only used when fp16 set to True.\n"
        "0 (default value): dynamic loss scaling.\n"
        "Positive power of 2: static loss scaling value.\n")

    ## Our arguments
    parser.add_argument(
        "--mode",
        choices=[
            "none", "distill", "smoothed_distill", "smoothed_distill_annealed",
            "label_smoothing", "theta_smoothed_distill", "reweight_baseline",
            "smoothed_reweight_baseline", "permute_smoothed_distill",
            "bias_product_baseline", "learned_mixin_baseline",
            "reweight_by_teacher", "reweight_by_teacher_annealed",
            "bias_product_by_teacher", "bias_product_by_teacher_annealed",
            "focal_loss"
        ])
    parser.add_argument("--penalty",
                        type=float,
                        default=0.03,
                        help="Penalty weight for the learn_mixin model")
    parser.add_argument("--focal_loss_gamma", type=float, default=1.0)
    parser.add_argument("--n_processes",
                        type=int,
                        default=4,
                        help="Processes to use for pre-processing")
    parser.add_argument("--debug", action="store_true")
    parser.add_argument("--debug_num", type=int, default=2000)
    parser.add_argument(
        "--sorted",
        action="store_true",
        help='Sort the data so most batches have the same input length,'
        ' which makes things about 2x faster. Our experiments did not actually'
        ' use this in the end (not sure if it makes a difference), so '
        "it's off by default.")
    parser.add_argument("--which_bias",
                        choices=["hans", "hypo", "hans_json", "mix", "dam"],
                        required=True)
    parser.add_argument("--custom_teacher", default=None)
    parser.add_argument("--custom_bias", default=None)
    parser.add_argument("--theta",
                        type=float,
                        default=0.1,
                        help="for theta smoothed distillation loss")
    parser.add_argument("--add_bias_on_eval", action="store_true")

    args = parser.parse_args()

    utils.add_stdout_logger()

    if args.mode == "none":
        loss_fn = clf_distill_loss_functions.Plain()
    elif args.mode == "distill":
        loss_fn = clf_distill_loss_functions.DistillLoss()
    elif args.mode == "smoothed_distill":
        loss_fn = clf_distill_loss_functions.SmoothedDistillLoss()
    elif args.mode == "smoothed_distill_annealed":
        loss_fn = clf_distill_loss_functions.SmoothedDistillLossAnnealed()
    elif args.mode == "theta_smoothed_distill":
        loss_fn = clf_distill_loss_functions.ThetaSmoothedDistillLoss(
            args.theta)
    elif args.mode == "label_smoothing":
        loss_fn = clf_distill_loss_functions.LabelSmoothing(3)
    elif args.mode == "reweight_baseline":
        loss_fn = clf_distill_loss_functions.ReweightBaseline()
    elif args.mode == "permute_smoothed_distill":
        loss_fn = clf_distill_loss_functions.PermuteSmoothedDistillLoss()
    elif args.mode == "smoothed_reweight_baseline":
        loss_fn = clf_distill_loss_functions.SmoothedReweightLoss()
    elif args.mode == "bias_product_baseline":
        loss_fn = clf_distill_loss_functions.BiasProductBaseline()
    elif args.mode == "learned_mixin_baseline":
        loss_fn = clf_distill_loss_functions.LearnedMixinBaseline(args.penalty)
    elif args.mode == "reweight_by_teacher":
        loss_fn = clf_distill_loss_functions.ReweightByTeacher()
    elif args.mode == "reweight_by_teacher_annealed":
        loss_fn = clf_distill_loss_functions.ReweightByTeacherAnnealed()
    elif args.mode == "bias_product_by_teacher":
        loss_fn = clf_distill_loss_functions.BiasProductByTeacher()
    elif args.mode == "bias_product_by_teacher_annealed":
        loss_fn = clf_distill_loss_functions.BiasProductByTeacherAnnealed()
    elif args.mode == "focal_loss":
        loss_fn = clf_distill_loss_functions.FocalLoss(
            gamma=args.focal_loss_gamma)
    else:
        raise RuntimeError()

    output_dir = args.output_dir

    if args.do_train:
        if exists(output_dir):
            if len(os.listdir(output_dir)) > 0:
                logging.warning("Output dir exists and is non-empty")
        else:
            os.makedirs(output_dir)

    print("Saving model to %s" % output_dir)

    if args.local_rank == -1 or args.no_cuda:
        device = torch.device("cuda" if torch.cuda.is_available()
                              and not args.no_cuda else "cpu")
        n_gpu = torch.cuda.device_count()
    else:
        torch.cuda.set_device(args.local_rank)
        device = torch.device("cuda", args.local_rank)
        n_gpu = 1
        # Initializes the distributed backend which will take care of synchronizing nodes/GPUs
        torch.distributed.init_process_group(backend='nccl')
    logging.info(
        "device: {} n_gpu: {}, distributed training: {}, 16-bits training: {}".
        format(device, n_gpu, bool(args.local_rank != -1), args.fp16))

    if args.gradient_accumulation_steps < 1:
        raise ValueError(
            "Invalid gradient_accumulation_steps parameter: {}, should be >= 1"
            .format(args.gradient_accumulation_steps))

    args.train_batch_size = args.train_batch_size // args.gradient_accumulation_steps

    if args.seed is not None:
        random.seed(args.seed)
        np.random.seed(args.seed)
        torch.manual_seed(args.seed)
        if n_gpu > 0:
            torch.cuda.manual_seed_all(args.seed)

    if not args.do_train and not args.do_eval and not args.do_test:
        raise ValueError(
            "At least one of `do_train`, `do_eval` or `do_test` must be True.")

    if os.path.exists(output_dir) and os.listdir(output_dir) and args.do_train:
        logging.warning(
            "Output directory ({}) already exists and is not empty.".format(
                output_dir))
    if not os.path.exists(output_dir):
        os.makedirs(output_dir)

    # It's way too easy to forget whether this is being set by a command-line flag
    if "-uncased" in args.bert_model:
        do_lower_case = True
    elif "-cased" in args.bert_model:
        do_lower_case = False
    else:
        raise NotImplementedError(args.bert_model)

    tokenizer = BertTokenizer.from_pretrained(args.bert_model,
                                              do_lower_case=do_lower_case)

    num_train_optimization_steps = None
    train_examples = None
    if args.do_train:
        train_examples = load_mnli(True,
                                   args.debug_num if args.debug else None)
        num_train_optimization_steps = int(
            len(train_examples) / args.train_batch_size /
            args.gradient_accumulation_steps) * args.num_train_epochs
        if args.local_rank != -1:
            num_train_optimization_steps = num_train_optimization_steps // torch.distributed.get_world_size(
            )
        loss_fn.num_train_optimization_steps = int(
            num_train_optimization_steps)
        loss_fn.num_epochs = int(args.num_train_epochs)

    # Prepare model
    cache_dir = args.cache_dir if args.cache_dir else os.path.join(
        str(PYTORCH_PRETRAINED_BERT_CACHE), 'distributed_{}'.format(
            args.local_rank))

    model = BertDistill.from_pretrained(args.bert_model,
                                        cache_dir=cache_dir,
                                        num_labels=3,
                                        loss_fn=loss_fn)

    if args.fp16:
        model.half()
    model.to(device)
    if args.local_rank != -1:
        try:
            from apex.parallel import DistributedDataParallel as DDP
        except ImportError:
            raise ImportError(
                "Please install apex from https://www.github.com/nvidia/apex to use distributed and fp16 training."
            )

        model = DDP(model)
    elif n_gpu > 1:
        model = torch.nn.DataParallel(model)

    # Prepare optimizer
    param_optimizer = list(model.named_parameters())
    no_decay = ['bias', 'LayerNorm.bias', 'LayerNorm.weight']
    optimizer_grouped_parameters = [{
        'params':
        [p for n, p in param_optimizer if not any(nd in n for nd in no_decay)],
        'weight_decay':
        0.01
    }, {
        'params':
        [p for n, p in param_optimizer if any(nd in n for nd in no_decay)],
        'weight_decay':
        0.0
    }]
    if args.fp16:
        try:
            from apex.optimizers import FP16_Optimizer
            from apex.optimizers import FusedAdam
        except ImportError:
            raise ImportError(
                "Please install apex from https://www.github.com/nvidia/apex to use distributed and fp16 training."
            )

        optimizer = FusedAdam(optimizer_grouped_parameters,
                              lr=args.learning_rate,
                              bias_correction=False,
                              max_grad_norm=1.0)
        if args.loss_scale == 0:
            optimizer = FP16_Optimizer(optimizer, dynamic_loss_scale=True)
        else:
            optimizer = FP16_Optimizer(optimizer,
                                       static_loss_scale=args.loss_scale)

    else:
        optimizer = BertAdam(optimizer_grouped_parameters,
                             lr=args.learning_rate,
                             warmup=args.warmup_proportion,
                             t_total=num_train_optimization_steps)

    global_step = 0
    nb_tr_steps = 0
    tr_loss = 0

    if args.do_train:
        train_features: List[InputFeatures] = convert_examples_to_features(
            train_examples, args.max_seq_length, tokenizer, args.n_processes)

        if args.which_bias == "mix":
            hypo_bias_map = load_bias("hypo")
            hans_bias_map = load_bias("hans")
            bias_map = {}

            def compute_entropy(probs, base=3):
                return -(probs * (np.log(probs) / np.log(base))).sum()

            for key in hypo_bias_map.keys():
                hypo_ent = compute_entropy(np.exp(hypo_bias_map[key]))
                hans_ent = compute_entropy(np.exp(hans_bias_map[key]))
                if hypo_ent < hans_ent:
                    bias_map[key] = hypo_bias_map[key]
                else:
                    bias_map[key] = hans_bias_map[key]
        else:
            bias_map = load_bias(args.which_bias, custom_path=args.custom_bias)

        for fe in train_features:
            fe.bias = bias_map[fe.example_id].astype(np.float32)
        teacher_probs_map = load_teacher_probs(args.custom_teacher)
        for fe in train_features:
            fe.teacher_probs = np.array(
                teacher_probs_map[fe.example_id]).astype(np.float32)

        example_map = {}
        for ex in train_examples:
            example_map[ex.id] = ex

        logging.info("***** Running training *****")
        logging.info("  Num examples = %d", len(train_examples))
        logging.info("  Batch size = %d", args.train_batch_size)
        logging.info("  Num steps = %d", num_train_optimization_steps)

        train_dataloader = build_train_dataloader(train_features,
                                                  args.train_batch_size,
                                                  args.seed, args.sorted)

        model.train()
        loss_ema = 0
        total_steps = 0
        decay = 0.99

        for _ in trange(int(args.num_train_epochs), desc="Epoch", ncols=100):
            tr_loss = 0
            nb_tr_examples, nb_tr_steps = 0, 0
            pbar = tqdm(train_dataloader, desc="loss", ncols=100)
            for step, batch in enumerate(pbar):
                batch = tuple(t.to(device) for t in batch)
                if bias_map is not None:
                    example_ids, input_ids, input_mask, segment_ids, label_ids, bias, teacher_probs = batch
                else:
                    bias = None
                    example_ids, input_ids, input_mask, segment_ids, label_ids = batch

                logits, loss = model(input_ids, segment_ids, input_mask,
                                     label_ids, bias, teacher_probs)

                total_steps += 1
                loss_ema = loss_ema * decay + loss.cpu().detach().numpy() * (
                    1 - decay)
                descript = "loss=%.4f" % (loss_ema / (1 - decay**total_steps))
                pbar.set_description(descript, refresh=False)

                if n_gpu > 1:
                    loss = loss.mean()  # mean() to average on multi-gpu.
                if args.gradient_accumulation_steps > 1:
                    loss = loss / args.gradient_accumulation_steps

                if args.fp16:
                    optimizer.backward(loss)
                else:
                    loss.backward()

                tr_loss += loss.item()
                nb_tr_examples += input_ids.size(0)
                nb_tr_steps += 1
                if (step + 1) % args.gradient_accumulation_steps == 0:
                    if args.fp16:
                        # modify the learning rate with the special warmup schedule BERT uses;
                        # if args.fp16 is False, BertAdam is used, which handles this automatically
                        lr_this_step = args.learning_rate * warmup_linear(
                            global_step / num_train_optimization_steps,
                            args.warmup_proportion)
                        for param_group in optimizer.param_groups:
                            param_group['lr'] = lr_this_step
                    optimizer.step()
                    optimizer.zero_grad()
                    global_step += 1

        # Save a trained model and the associated configuration
        model_to_save = model.module if hasattr(
            model, 'module') else model  # Only save the model itself
        output_model_file = os.path.join(output_dir, WEIGHTS_NAME)
        torch.save(model_to_save.state_dict(), output_model_file)
        output_config_file = os.path.join(output_dir, CONFIG_NAME)
        with open(output_config_file, 'w') as f:
            f.write(model_to_save.config.to_json_string())

        # Record the args as well
        arg_dict = {}
        for arg in vars(args):
            arg_dict[arg] = getattr(args, arg)
        with open(join(output_dir, "args.json"), 'w') as out_fh:
            json.dump(arg_dict, out_fh)

        # Load a trained model and config that you have fine-tuned
        config = BertConfig(output_config_file)
        model = BertDistill(config, num_labels=3, loss_fn=loss_fn)
        model.load_state_dict(torch.load(output_model_file))
    else:
        output_config_file = os.path.join(output_dir, CONFIG_NAME)
        config = BertConfig.from_json_file(output_config_file)
        output_model_file = os.path.join(output_dir, WEIGHTS_NAME)
        model = BertDistill(config, num_labels=3, loss_fn=loss_fn)
        model.load_state_dict(torch.load(output_model_file))

    model.to(device)

    if not args.do_eval and not args.do_test:
        return
    if not (args.local_rank == -1 or torch.distributed.get_rank() == 0):
        return

    model.eval()

    if args.do_eval:
        eval_datasets = [("mnli_dev_m", load_mnli(False)),
                         ("mnli_dev_mm",
                          load_mnli(False, custom_path="dev_mismatched.tsv"))]
        # eval_datasets += load_easy_hard(prefix="overlap_", no_mismatched=True)
        eval_datasets += load_easy_hard()
        eval_datasets += [("hans", load_hans())]
        eval_datasets += load_hans_subsets()

        # stress test
        # eval_datasets += [("negation_m", load_jsonl("multinli_0.9_negation_matched.jsonl",
        #                                             "../dataset/StressTests/Negation"))]
        # eval_datasets += [("negation_mm", load_jsonl("multinli_0.9_negation_mismatched.jsonl",
        #                                              "../dataset/StressTests/Negation"))]
        # eval_datasets += [("overlap_m", load_jsonl("multinli_0.9_taut2_matched.jsonl",
        #                                             "../dataset/StressTests/Word_Overlap"))]
        # eval_datasets += [("overlap_mm", load_jsonl("multinli_0.9_taut2_mismatched.jsonl",
        #                                              "../dataset/StressTests/Word_Overlap"))]
        # eval_datasets += [("length_m", load_jsonl("multinli_0.9_length_mismatch_matched.jsonl",
        #                                             "../dataset/StressTests/Length_Mismatch"))]
        # eval_datasets += [("length_mm", load_jsonl("multinli_0.9_length_mismatch_mismatched.jsonl",
        #                                              "../dataset/StressTests/Length_Mismatch"))]

        # eval_datasets = [("rte", load_jsonl("eval_rte.jsonl",
        #                                     "../dataset/mnli_eval_suite"))]
        # eval_datasets += [("rte_glue", load_jsonl("eval_glue_rte.jsonl",
        #                                          "../dataset/mnli_eval_suite"))]
        # eval_datasets += [("sick", load_jsonl("eval_sick.jsonl",
        #                                       "../dataset/mnli_eval_suite"))]
        # eval_datasets += [("diagnostic", load_jsonl("diagnostic-full.jsonl",
        #                                             "../dataset/mnli_eval_suite"))]
        # eval_datasets += [("scitail", load_jsonl("scitail_1.0_test.txt",
        #                                           "../dataset/scitail/snli_format"))]

        # todo delete
        if args.do_eval_on_train:
            eval_datasets = [("mnli_train", load_mnli(True))]
    else:
        eval_datasets = []

    if args.do_test:
        test_datasets = load_all_test_jsonl()
        eval_datasets += test_datasets
        subm_paths = [
            "../submission/{}.csv".format(x[0]) for x in test_datasets
        ]

    for ix, (name, eval_examples) in enumerate(eval_datasets):
        logging.info("***** Running evaluation on %s *****" % name)
        logging.info("  Num examples = %d", len(eval_examples))
        logging.info("  Batch size = %d", args.eval_batch_size)
        eval_features = convert_examples_to_features(eval_examples,
                                                     args.max_seq_length,
                                                     tokenizer)
        eval_features.sort(key=lambda x: len(x.input_ids))
        all_label_ids = np.array([x.label_id for x in eval_features])
        eval_dataloader = build_eval_dataloader(eval_features,
                                                args.eval_batch_size)

        eval_loss = 0
        nb_eval_steps = 0
        probs = []
        test_subm_ids = []

        for example_ids, input_ids, input_mask, segment_ids, label_ids in tqdm(
                eval_dataloader, desc="Evaluating", ncols=100):
            input_ids = input_ids.to(device)
            input_mask = input_mask.to(device)
            segment_ids = segment_ids.to(device)
            label_ids = label_ids.to(device)

            with torch.no_grad():
                logits = model(input_ids, segment_ids, input_mask)

            # create eval loss and other metric required by the task
            loss_fct = CrossEntropyLoss()
            tmp_eval_loss = loss_fct(logits.view(-1, 3), label_ids.view(-1))

            eval_loss += tmp_eval_loss.mean().item()
            nb_eval_steps += 1
            probs.append(
                torch.nn.functional.softmax(logits, 1).detach().cpu().numpy())
            test_subm_ids.append(example_ids.cpu().numpy())

        probs = np.concatenate(probs, 0)
        test_subm_ids = np.concatenate(test_subm_ids, 0)
        eval_loss = eval_loss / nb_eval_steps

        if "hans" in name:
            # take the max of the non-entailment classes rather than their sum
            # (this collapse is sketched after this example)
            probs[:, 0] = probs[:, [0, 2]].max(axis=1)
            # probs[:, 0] = probs[:, 0] + probs[:, 2]
            probs = probs[:, :2]

        preds = np.argmax(probs, axis=1)

        result = {"acc": simple_accuracy(preds, all_label_ids)}
        result["loss"] = eval_loss

        conf_plot_file = os.path.join(output_dir,
                                      "eval_%s_confidence.png" % name)
        ECE, bins_acc, bins_conf, bins_num = visualize_predictions(
            probs, all_label_ids, conf_plot_file=conf_plot_file)
        result["ECE"] = ECE
        result["bins_acc"] = bins_acc
        result["bins_conf"] = bins_conf
        result["bins_num"] = bins_num

        output_eval_file = os.path.join(output_dir,
                                        "eval_%s_results.txt" % name)
        output_all_eval_file = os.path.join(output_dir, "eval_all_results.txt")
        with open(output_eval_file,
                  "w") as writer, open(output_all_eval_file,
                                       "a") as all_writer:
            logging.info("***** Eval results *****")
            all_writer.write("eval results on %s:\n" % name)
            for key in sorted(result.keys()):
                logging.info("  %s = %s", key, str(result[key]))
                writer.write("%s = %s\n" % (key, str(result[key])))
                all_writer.write("%s = %s\n" % (key, str(result[key])))

        output_answer_file = os.path.join(output_dir,
                                          "eval_%s_answers.json" % name)
        answers = {
            ex.example_id: [float(x) for x in p]
            for ex, p in zip(eval_features, probs)
        }
        with open(output_answer_file, "w") as f:
            json.dump(answers, f)

        # prepare submission file
        if args.do_test and ix >= len(eval_datasets) - len(test_datasets):
            with open(subm_paths.pop(0), "w") as subm_f:
                subm_f.write("pairID,gold_label\n")
                for sub_id, pred_label_id in zip(test_subm_ids, preds):
                    subm_f.write("{},{}\n".format(
                        str(sub_id), REV_NLI_LABEL_MAP[pred_label_id]))
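
When a HANS-style dataset is evaluated above, the three-way MNLI probabilities are reduced to HANS's two-way label space by taking the max of the two non-entailment columns rather than their sum. A minimal numpy sketch of that collapse (indices 0 and 2 are treated as the non-entailment classes, as in the slicing above; the helper name is chosen here for illustration):

import numpy as np

def collapse_for_hans(probs):
    # probs: [num_examples, 3] softmax outputs; columns 0 and 2 are the two
    # non-entailment classes (per the slicing in the code above).
    probs = probs.copy()
    probs[:, 0] = probs[:, [0, 2]].max(axis=1)  # strongest non-entailment score
    return probs[:, :2]                         # two-way [non-entailment, entailment]

# collapse_for_hans(np.array([[0.2, 0.5, 0.3]])) -> array([[0.3, 0.5]])
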
Example #15
0
def main():
    start_full = time.time()
    global best_prec1, args

    time_stat = []
    start = time.time()

    args.gpu = 0
    args.world_size = 1

    if args.distributed:
        args.gpu = args.local_rank % torch.cuda.device_count()
        torch.cuda.set_device(args.gpu)
        torch.distributed.init_process_group(backend='nccl',
                                             init_method='env://')
        args.world_size = torch.distributed.get_world_size()

    args.total_batch_size = args.world_size * args.batch_size

    if args.fp16:
        assert torch.backends.cudnn.enabled, "fp16 mode requires cudnn backend to be enabled."

    if args.amp and args.fp16:
        print("Please use only one of the --fp16/--amp flags")
        exit(1)

    if args.static_loss_scale != 1.0:
        if not args.fp16:
            print(
                "Warning:  if --fp16 is not used, static_loss_scale will be ignored."
            )

    print("=> creating model AudioNet")
    model = audio_model.AudioNet(num_classes=args.classes)
    model = model.cuda()

    # note: sync BN conversion needs the model to exist first
    if args.sync_bn:
        import apex
        print("using apex synced BN")
        model = apex.parallel.convert_syncbn_model(model)

    if args.fp16:
        model = network_to_half(model)

    # As in the paper, use an Adam optimizer with weight decay (args.weight_decay).
    # The learning rate starts at args.lr and a StepLR ``scheduler`` decreases it
    # by a factor of 10 every 20 epochs during training.
    optimizer = optim.Adam(model.parameters(),
                           lr=args.lr,
                           weight_decay=args.weight_decay)
    scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=20, gamma=0.1)

    if args.fp16:
        optimizer = FP16_Optimizer(optimizer,
                                   static_loss_scale=args.static_loss_scale,
                                   dynamic_loss_scale=args.dynamic_loss_scale)

    # Initialize Amp.  Amp accepts either values or strings for the optional override arguments,
    # for convenient interoperation with argparse
    if args.amp:
        model, optimizer = amp.initialize(
            model,
            optimizer,
            opt_level=args.opt_level,
            keep_batchnorm_fp32=args.keep_batchnorm_fp32,
            loss_scale=args.loss_scale,
            min_loss_scale=1.0)

    # For distributed training, wrap the model with apex.parallel.DistributedDataParallel.
    # This must be done AFTER the call to amp.initialize.  If model = DDP(model) is called
    # before model, ... = amp.initialize(model, ...), the call to amp.initialize may alter
    # the types of model's parameters in a way that disrupts or destroys DDP's allreduce hooks.
    if args.distributed:
        # shared_param/delay_allreduce turns off bucketing in DDP; for lower-latency runs this can improve perf
        # for older versions of APEX use shared_param; for newer ones it is delay_allreduce
        model = DDP(model, delay_allreduce=True)

    # define loss function (criterion) and optimizer
    #criterion = nn.CrossEntropyLoss().cuda()

    # optionally resume from a checkpoint
    if args.resume:
        if os.path.isfile(args.resume):
            print("=> loading checkpoint '{}'".format(args.resume))
            checkpoint = torch.load(
                args.resume,
                map_location=lambda storage, loc: storage.cuda(args.gpu))
            args.start_epoch = checkpoint['epoch']
            best_prec1 = checkpoint['best_prec1']
            model.load_state_dict(checkpoint['state_dict'])
            optimizer.load_state_dict(checkpoint['optimizer'])
            print("=> loaded checkpoint '{}' (epoch {})".format(
                args.resume, checkpoint['epoch']))
        else:
            print("=> no checkpoint found at '{}'".format(args.resume))

    # Data loading code
    if len(args.data) == 1:
        traindir = os.path.join(args.data[0], 'train')
        valdir = os.path.join(args.data[0], 'val')
    else:
        traindir = args.data[0]
        valdir = args.data[1]

    pipe = AudioTrainPipe(batch_size=args.batch_size,
                          num_threads=args.workers,
                          device_id=args.local_rank,
                          data_dir=traindir,
                          dali_cpu=args.dali_cpu)
    pipe.build()
    train_loader = DALIAudioClassificationIterator(
        pipe, size=int(pipe.epoch_size("Reader") / args.world_size))

    pipe = AudioValPipe(batch_size=args.batch_size,
                        num_threads=args.workers,
                        device_id=args.local_rank,
                        data_dir=valdir,
                        dali_cpu=args.dali_cpu)
    pipe.build()
    val_loader = DALIAudioClassificationIterator(
        pipe, size=int(pipe.epoch_size("Reader") / args.world_size))

    if args.evaluate:
        validate(val_loader, model)
        return

    total_time = AverageMeter()  # running-average helper (sketched after this example)
    dur_setup = time.time() - start
    time_stat.append(dur_setup)
    print("Batch size for GPU {} is {}, workers={}".format(
        args.gpu, args.batch_size, args.workers))

    for epoch in range(args.start_epoch, args.epochs):

        # log timing
        start_ep = time.time()

        # train for one epoch

        avg_train_time = train(train_loader, model, optimizer, epoch)
        total_time.update(avg_train_time)
        if args.prof:
            break
        # evaluate on validation set
        [prec1, prec5] = validate(val_loader, model)

        # remember best prec@1 and save checkpoint
        if args.local_rank == 0:
            is_best = prec1 > best_prec1
            best_prec1 = max(prec1, best_prec1)
            save_checkpoint(
                {
                    'epoch': epoch + 1,
                    'arch': args.arch,
                    'state_dict': model.state_dict(),
                    'best_prec1': best_prec1,
                    'optimizer': optimizer.state_dict(),
                }, is_best)
            if epoch == args.epochs - 1:
                print('##Top-1 {0}\n'
                      '##Top-5 {1}\n'
                      '##Perf  {2}'.format(
                          prec1, prec5,
                          args.total_batch_size / total_time.avg))

        scheduler.step()
        # reset DALI iterators
        train_loader.reset()
        val_loader.reset()

        dur_ep = time.time() - start_ep
        os.system("du -sh /dev/shm/cache/")
        time_stat.append(dur_ep)
        print("Epoch duration={}".format(dur_ep))

    if args.local_rank == 0:
        for i in time_stat:
            print("Time_stat : {}".format(i))

        for i in range(0, len(data_time_list)):
            print("Data time : {}\t Compute time : {}".format(
                data_time_list[i], compute_time_list[i]))

    dur_full = time.time() - start_full
    if args.local_rank == 0:
        print("Total time for all epochs = {}".format(dur_full))
Example #16
0
def main():
    parser = argparse.ArgumentParser()

    ## Required parameters
    parser.add_argument(
        "--data_dir",
        default=None,
        type=str,
        required=True,
        help=
        "The input data dir. Should contain the .csv files (or other data files) for the task."
    )

    parser.add_argument(
        "--output_sentvec_file",
        default=None,
        type=str,
        required=True,
        help="The output file of extracted embedding files of sentences.")

    parser.add_argument(
        "--data_split_to_extract",
        default=None,
        type=str,
        required=True,
        help="The output file of extracted embedding files of sentences.")

    parser.add_argument(
        "--bert_model",
        default=None,
        type=str,
        required=True,
        help="Bert pre-trained model selected in the list: bert-base-uncased, "
        "bert-large-uncased, bert-base-cased, bert-large-cased, bert-base-multilingual-uncased, "
        "bert-base-multilingual-cased, bert-base-chinese.")
    parser.add_argument(
        "--output_dir",
        default=None,
        type=str,
        required=True,
        help="The output directory where the model checkpoints will be written."
    )

    parser.add_argument("--epoch_id",
                        default=0,
                        type=int,
                        help="Epoch id to extract.")

    parser.add_argument(
        "--save_model_name",
        default="model",
        type=str,
        required=True,
        help=
        "The output model name where the model checkpoints will be written.")

    ## Other parameters
    parser.add_argument(
        "--max_seq_length",
        default=128,
        type=int,
        help=
        "The maximum total input sequence length after WordPiece tokenization. \n"
        "Sequences longer than this will be truncated, and sequences shorter \n"
        "than this will be padded.")
    parser.add_argument("--do_train",
                        action='store_true',
                        help="Whether to run training.")

    parser.add_argument("--with_dev",
                        action='store_true',
                        help="Whether to run training with dev.")

    parser.add_argument("--do_eval",
                        action='store_true',
                        help="Whether to run eval on the dev set.")

    parser.add_argument("--do_test",
                        action='store_true',
                        help="Whether to run test on the test set.")

    parser.add_argument(
        "--do_lower_case",
        action='store_true',
        help="Set this flag if you are using an uncased model.")
    parser.add_argument("--train_batch_size",
                        default=32,
                        type=int,
                        help="Total batch size for training.")
    parser.add_argument("--eval_batch_size",
                        default=8,
                        type=int,
                        help="Total batch size for eval.")
    parser.add_argument("--learning_rate",
                        default=5e-5,
                        type=float,
                        help="The initial learning rate for Adam.")

    parser.add_argument("--layer_id",
                        default=-1,
                        type=int,
                        help="Output Layer Id")

    parser.add_argument("--mlp_hidden_dim",
                        default=64,
                        type=int,
                        help="mlp_hidden_dim.")

    parser.add_argument("--mlp_dropout",
                        default=0.1,
                        type=float,
                        help="hidden drop out")

    parser.add_argument("--num_train_epochs",
                        default=3.0,
                        type=float,
                        help="Total number of training epochs to perform.")
    parser.add_argument(
        "--warmup_proportion",
        default=0.1,
        type=float,
        help=
        "Proportion of training to perform linear learning rate warmup for. "
        "E.g., 0.1 = 10%% of training.")
    parser.add_argument("--no_cuda",
                        action='store_true',
                        help="Whether not to use CUDA when available")
    parser.add_argument("--local_rank",
                        type=int,
                        default=-1,
                        help="local_rank for distributed training on gpus")
    parser.add_argument('--seed',
                        type=int,
                        default=42,
                        help="random seed for initialization")

    parser.add_argument('--patience',
                        type=int,
                        default=5,
                        help="early stop epoch nums on dev")

    parser.add_argument(
        '--gradient_accumulation_steps',
        type=int,
        default=1,
        help=
        "Number of updates steps to accumulate before performing a backward/update pass."
    )
    parser.add_argument(
        '--fp16',
        action='store_true',
        help="Whether to use 16-bit float precision instead of 32-bit")
    parser.add_argument(
        '--loss_scale',
        type=float,
        default=0,
        help=
        "Loss scaling to improve fp16 numeric stability. Only used when fp16 set to True.\n"
        "0 (default value): dynamic loss scaling.\n"
        "Positive power of 2: static loss scaling value.\n")

    args = parser.parse_args()
    print("torch.cuda.is_available()", torch.cuda.is_available())
    if args.local_rank == -1 or args.no_cuda:
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        n_gpu = torch.cuda.device_count()
    else:
        torch.cuda.set_device(args.local_rank)
        device = torch.device("cuda", args.local_rank)
        n_gpu = 1
        # Initializes the distributed backend which will take care of synchronizing nodes/GPUs
        torch.distributed.init_process_group(backend='nccl')
    logger.info(
        "device: {} n_gpu: {}, distributed training: {}, 16-bits training: {}".
        format(device, n_gpu, bool(args.local_rank != -1), args.fp16))

    if args.gradient_accumulation_steps < 1:
        raise ValueError(
            "Invalid gradient_accumulation_steps parameter: {}, should be >= 1"
            .format(args.gradient_accumulation_steps))

    args.train_batch_size = args.train_batch_size // args.gradient_accumulation_steps

    random.seed(args.seed)
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)
    if n_gpu > 0:
        torch.cuda.manual_seed_all(args.seed)

    if not args.do_train and not args.do_eval and not args.do_test:
        raise ValueError(
            "At least one of `do_train`, `do_eval` or `do_test` must be True.")

    if os.path.exists(args.output_dir) and os.listdir(args.output_dir):
        print(
            "WARNING: Output directory ({}) already exists and is not empty.".
            format(args.output_dir))
        # raise ValueError("Output directory ({}) already exists and is not empty.".format(args.output_dir))
    if not os.path.exists(args.output_dir):
        os.makedirs(args.output_dir)

    tokenizer = BertTokenizer.from_pretrained(args.bert_model,
                                              do_lower_case=args.do_lower_case)

    train_examples = None
    num_train_optimization_steps = None
    if args.do_train:
        train_examples = read_csqa_examples(
            os.path.join(args.data_dir, 'train_rand_split.jsonl'))
        dev_examples = read_csqa_examples(
            os.path.join(args.data_dir, 'dev_rand_split.jsonl'))
        print(len(train_examples))
        if args.with_dev:
            train_examples += dev_examples
            print(len(train_examples))

        num_train_optimization_steps = int(
            len(train_examples) / args.train_batch_size /
            args.gradient_accumulation_steps) * args.num_train_epochs
        if args.local_rank != -1:
            num_train_optimization_steps = num_train_optimization_steps // torch.distributed.get_world_size(
            )

    # Prepare model
    model = BertForMultipleChoiceExtraction.from_pretrained(
        args.bert_model,
        cache_dir=os.path.join(PYTORCH_PRETRAINED_BERT_CACHE,
                               'distributed_{}'.format(args.local_rank)),
        num_choices=5,
        mlp_hidden_dim=args.mlp_hidden_dim,
        mlp_dropout=args.mlp_dropout)
    if args.fp16:
        model.half()
    model.to(device)
    if args.local_rank != -1:
        try:
            from apex.parallel import DistributedDataParallel as DDP
        except ImportError:
            raise ImportError(
                "Please install apex from https://www.github.com/nvidia/apex to use distributed and fp16 training."
            )

        model = DDP(model)
    elif n_gpu > 1:
        model = torch.nn.DataParallel(model)

    # Prepare optimizer
    param_optimizer = list(model.named_parameters())

    # hack to remove the pooler, which is not used
    # and thus produces None grads that break apex
    param_optimizer = [n for n in param_optimizer if 'pooler' not in n[0]]

    no_decay = ['bias', 'LayerNorm.bias', 'LayerNorm.weight']
    optimizer_grouped_parameters = [{
        'params':
        [p for n, p in param_optimizer if not any(nd in n for nd in no_decay)],
        'weight_decay':
        0.01
    }, {
        'params':
        [p for n, p in param_optimizer if any(nd in n for nd in no_decay)],
        'weight_decay':
        0.0
    }]
    if args.fp16:
        try:
            from apex.optimizers import FP16_Optimizer
            from apex.optimizers import FusedAdam
        except ImportError:
            raise ImportError(
                "Please install apex from https://www.github.com/nvidia/apex to use distributed and fp16 training."
            )

        optimizer = FusedAdam(optimizer_grouped_parameters,
                              lr=args.learning_rate,
                              bias_correction=False,
                              max_grad_norm=1.0)
        if args.loss_scale == 0:
            optimizer = FP16_Optimizer(optimizer, dynamic_loss_scale=True)
        else:
            optimizer = FP16_Optimizer(optimizer,
                                       static_loss_scale=args.loss_scale)
    else:
        optimizer = BertAdam(optimizer_grouped_parameters,
                             lr=args.learning_rate,
                             warmup=args.warmup_proportion,
                             t_total=num_train_optimization_steps)

    global_step = 0

    model.to(device)

    if args.do_eval and (args.local_rank == -1
                         or torch.distributed.get_rank() == 0):
        # Load a trained model and config that you have fine-tuned
        output_model_file = os.path.join(
            args.output_dir,
            args.save_model_name + ".bin.%d" % (args.epoch_id))
        output_config_file = os.path.join(args.output_dir,
                                          args.save_model_name + ".config")
        config = BertConfig(output_config_file)
        model = BertForMultipleChoiceExtraction(
            config,
            num_choices=5,
            mlp_hidden_dim=args.mlp_hidden_dim,
            mlp_dropout=args.mlp_dropout)
        model.load_state_dict(torch.load(output_model_file))
        model.to(device)
        # extract embeddings from the requested data split (e.g. 'dev_rand_split.jsonl')
        eval_examples = read_csqa_examples(
            os.path.join(args.data_dir, args.data_split_to_extract))

        eval_features = convert_examples_to_features(eval_examples, tokenizer,
                                                     args.max_seq_length, True)
        logger.info("***** Running evaluation *****")
        logger.info("  Num examples = %d", len(eval_examples))
        logger.info("  Batch size = %d", args.eval_batch_size)
        # select_field gathers one field per answer choice (sketched after this example)
        all_input_ids = torch.tensor(select_field(eval_features, 'input_ids'),
                                     dtype=torch.long)
        all_input_mask = torch.tensor(select_field(eval_features,
                                                   'input_mask'),
                                      dtype=torch.long)
        all_segment_ids = torch.tensor(select_field(eval_features,
                                                    'segment_ids'),
                                       dtype=torch.long)
        all_label = torch.tensor([f.label for f in eval_features],
                                 dtype=torch.long)
        eval_data = TensorDataset(all_input_ids, all_input_mask,
                                  all_segment_ids, all_label)
        # Run prediction for full data
        eval_sampler = SequentialSampler(eval_data)
        eval_dataloader = DataLoader(eval_data,
                                     sampler=eval_sampler,
                                     batch_size=args.eval_batch_size)

        model.eval()
        eval_loss, eval_accuracy = 0, 0
        nb_eval_steps, nb_eval_examples = 0, 0
        pooled_sent_vecs = []
        for input_ids, input_mask, segment_ids, label_ids in tqdm(
                eval_dataloader, desc="Iteration"):
            input_ids = input_ids.to(device)
            input_mask = input_mask.to(device)
            segment_ids = segment_ids.to(device)
            label_ids = label_ids.to(device)

            with torch.no_grad():
                # tmp_eval_loss, pooled_output = model(input_ids, segment_ids, input_mask, label_ids)
                logits, pooled_output = model(input_ids,
                                              segment_ids,
                                              input_mask,
                                              layer_id=args.layer_id)
            pooled_sent_vecs.append(pooled_output)
            # print(pooled_output.size())
            logits = logits.detach().cpu().numpy()
            label_ids = label_ids.to('cpu').numpy()

            tmp_eval_accuracy = accuracy(logits, label_ids)

            eval_accuracy += tmp_eval_accuracy

            nb_eval_examples += input_ids.size(0)
            nb_eval_steps += 1

        pooled_sent_vecs = torch.cat(pooled_sent_vecs, dim=0)
        print(pooled_sent_vecs.size())
        output_numpy = pooled_sent_vecs.to('cpu').numpy()
        print(output_numpy.shape)

        np.save(args.output_sentvec_file + ".%d" % (args.layer_id),
                output_numpy)

        eval_accuracy = eval_accuracy / nb_eval_examples

        result = {'eval_accuracy': eval_accuracy}
        logger.info("***** Eval results *****")
        for key in sorted(result.keys()):
            logger.info("  %s = %s", key, str(result[key]))
Example #17
0
def main():
    cudnn.benchmark = True
    best_prec1 = 0
    if args.deterministic:
        cudnn.benchmark = False
        cudnn.deterministic = True
        torch.manual_seed(args.local_rank)
        torch.set_printoptions(precision=10)

    # handle distributed training
    args.distributed = False
    if 'WORLD_SIZE' in os.environ:
        args.distributed = int(os.environ['WORLD_SIZE']) > 1
    args.gpu = 0
    args.world_size = 1
    if args.distributed:
        args.gpu = args.local_rank
        torch.cuda.set_device(args.gpu)
        torch.distributed.init_process_group(backend='nccl',
                                             init_method='env://')
        args.world_size = torch.distributed.get_world_size()

    assert torch.backends.cudnn.enabled, "Amp requires cudnn backend to be enabled."

    global base_learning_rate
    global best_acc1
    base_learning_rate = args.base_lr * float(
        args.batch_size * args.world_size) / 256.

    # create model
    if args.pretrained:
        print("=> using pre-trained model '{}'".format(args.arch))
        model = models.__dict__[args.arch](pretrained=True)
    else:
        print("=> creating model '{}'".format(args.arch))
        model = models.__dict__[args.arch]()

    model = model.cuda()

    # define loss function (criterion) and optimizer
    # (a readable sketch of this lambda criterion follows this example)
    cel = nn.CrossEntropyLoss()
    criterion = lambda pred, target, lam: (-F.log_softmax(
        pred, dim=1) * torch.zeros(pred.size()).cuda().scatter_(
            1, target.data.view(-1, 1), lam.view(-1, 1))).sum(dim=1).mean()
    parameters_bias = [
        p[1] for p in model.named_parameters() if 'bias' in p[0]
    ]
    parameters_scale = [
        p[1] for p in model.named_parameters() if 'scale' in p[0]
    ]
    parameters_others = [
        p[1] for p in model.named_parameters()
        if not ('bias' in p[0] or 'scale' in p[0])
    ]
    optimizer = torch.optim.SGD([{
        'params': parameters_bias,
        'lr': args.base_lr / 10.
    }, {
        'params': parameters_scale,
        'lr': args.base_lr / 10.
    }, {
        'params': parameters_others
    }],
                                lr=base_learning_rate,
                                momentum=args.momentum,
                                weight_decay=args.weight_decay)

    # Initialize Amp.  Amp accepts either values or strings for the optional override arguments,
    # for convenient interoperation with argparse.
    model, optimizer = amp.initialize(
        model,
        optimizer,
        opt_level=args.opt_level,
        keep_batchnorm_fp32=args.keep_batchnorm_fp32,
        loss_scale=args.loss_scale)

    #torch.cuda.set_device(args.gpu)
    model = DDP(model, delay_allreduce=True)
    #model = torch.nn.DataParallel(model)

    # optionally resume from a checkpoint
    if args.resume:
        if os.path.isfile(args.resume):
            print("=> loading checkpoint '{}'".format(args.resume))
            checkpoint = torch.load(args.resume)
            args.start_epoch = checkpoint['epoch']
            best_acc1 = checkpoint['best_acc1']
            if args.gpu is not None:
                # best_acc1 may be from a checkpoint from a different GPU
                best_acc1 = best_acc1.to(args.gpu)
            model.load_state_dict(checkpoint['state_dict'])
            optimizer.load_state_dict(checkpoint['optimizer'])
            print("=> loaded checkpoint '{}' (epoch {})".format(
                args.resume, checkpoint['epoch']))
        else:
            print("=> no checkpoint found at '{}'".format(args.resume))

    cudnn.benchmark = True

    # Data loading code
    traindir = os.path.join(args.data, 'train')
    valdir = os.path.join(args.data, 'val')
    normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                     std=[0.229, 0.224, 0.225])

    train_dataset = datasets.ImageFolder(
        traindir,
        transforms.Compose([
            transforms.RandomResizedCrop(224),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
            normalize,
        ]))

    if args.distributed:
        train_sampler = torch.utils.data.distributed.DistributedSampler(
            train_dataset)
    else:
        train_sampler = None

    train_loader = torch.utils.data.DataLoader(train_dataset,
                                               batch_size=args.batch_size,
                                               shuffle=(train_sampler is None),
                                               num_workers=args.workers,
                                               pin_memory=True,
                                               sampler=train_sampler)

    val_loader = torch.utils.data.DataLoader(datasets.ImageFolder(
        valdir,
        transforms.Compose([
            transforms.Resize(256),
            transforms.CenterCrop(224),
            transforms.ToTensor(),
            normalize,
        ])),
                                             batch_size=args.batch_size,
                                             shuffle=False,
                                             num_workers=args.workers,
                                             pin_memory=True)

    if args.evaluate:
        # use the plain cross-entropy criterion for evaluation; the mixup-style
        # lambda above expects an extra `lam` argument
        validate(val_loader, model, cel, args)
        return

    # NOTE: this cosine scheduler is created but never stepped in this example;
    # adjust_learning_rate() below is what actually drives the learning rate.
    sgdr = CosineAnnealingLR(optimizer, args.epochs, eta_min=0, last_epoch=-1)
    for epoch in range(args.start_epoch, args.epochs):
        if args.distributed:
            train_sampler.set_epoch(epoch)
        adjust_learning_rate(optimizer, epoch, args)

        # train for one epoch
        train(train_loader, model, criterion, optimizer, epoch, args)

        # evaluate on validation set
        acc1 = validate(val_loader, model, cel, args)

        # remember best acc@1 and save checkpoint
        is_best = acc1 > best_acc1
        best_acc1 = max(acc1, best_acc1)

        if args.local_rank == 0:
            save_checkpoint(
                {
                    'epoch': epoch + 1,
                    'arch': args.arch,
                    'state_dict': model.state_dict(),
                    'best_acc1': best_acc1,
                    'optimizer': optimizer.state_dict(),
                }, is_best)
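
# A minimal, hypothetical sketch of the save_checkpoint helper referenced above
# (its real definition is not part of this excerpt). The assumption is the
# common PyTorch ImageNet-example pattern: always write the latest state and
# copy it aside when it is the best so far.
# import shutil
# import torch
#
# def save_checkpoint(state, is_best, filename='checkpoint.pth.tar'):
#     torch.save(state, filename)                           # latest state
#     if is_best:
#         shutil.copyfile(filename, 'model_best.pth.tar')   # keep best copy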
def main():
    parser = argparse.ArgumentParser()

    ## Required parameters
    parser.add_argument(
        "--data_dir",
        default="",
        type=str,
        required=True,
        help=
        "The input data dir. Should contain the .tsv files (or other data files) for the task."
    )
    parser.add_argument(
        "--model_file",
        default="",
        type=str,
        required=True,
        help="Bert pre-trained model selected in the list: bert-base-uncased, "
        "bert-large-uncased, bert-base-cased, bert-large-cased, bert-base-multilingual-uncased, "
        "bert-base-multilingual-cased, bert-base-chinese or any pretrained model directory with model.bin and config file"
    )
    parser.add_argument(
        "--bert_model",
        default="",
        type=str,
        required=True,
        help="Bert pre-trained model selected in the list: bert-base-uncased, "
        "bert-large-uncased, bert-base-cased, bert-large-cased, bert-base-multilingual-uncased, "
        "bert-base-multilingual-cased, bert-base-chinese.")
    parser.add_argument("--task_name",
                        default="",
                        type=str,
                        required=True,
                        help="The name of the task to train.")
    parser.add_argument(
        "--output_dir",
        default="",
        type=str,
        required=True,
        help=
        "The output directory where the model predictions and checkpoints will be written."
    )

    parser.add_argument("--num_parts_start",
                        default=-1,
                        type=int,
                        required=True,
                        help="Number of partitions to run train and test on")

    parser.add_argument("--num_parts_end",
                        default=-1,
                        type=int,
                        required=True,
                        help="Number of partitions to run train and test on")

    parser.add_argument("--task_num",
                        default=-1,
                        type=int,
                        required=True,
                        help="Number of partitions to run train and test on")

    ## Other parameters
    parser.add_argument(
        "--cache_dir",
        default="",
        type=str,
        help=
        "Where do you want to store the pre-trained models downloaded from s3")
    parser.add_argument(
        "--max_seq_length",
        default=512,
        type=int,
        help=
        "The maximum total input sequence length after WordPiece tokenization. \n"
        "Sequences longer than this will be truncated, and sequences shorter \n"
        "than this will be padded.")
    parser.add_argument("--do_train",
                        action='store_true',
                        help="Whether to run training.")
    parser.add_argument("--do_eval",
                        action='store_true',
                        help="Whether to run eval on the dev set.")
    parser.add_argument(
        "--do_lower_case",
        action='store_true',
        help="Set this flag if you are using an uncased model.")
    parser.add_argument("--train_batch_size",
                        default=32,
                        type=int,
                        help="Total batch size for training.")
    parser.add_argument("--eval_batch_size",
                        default=32,
                        type=int,
                        help="Total batch size for eval.")
    parser.add_argument("--learning_rate",
                        default=5e-5,
                        type=float,
                        help="The initial learning rate for Adam.")
    parser.add_argument("--num_train_epochs",
                        default=3.0,
                        type=float,
                        help="Total number of training epochs to perform.")
    parser.add_argument(
        "--warmup_proportion",
        default=0.1,
        type=float,
        help=
        "Proportion of training to perform linear learning rate warmup for. "
        "E.g., 0.1 = 10%% of training.")
    parser.add_argument("--no_cuda",
                        action='store_true',
                        help="Whether not to use CUDA when available")
    parser.add_argument("--local_rank",
                        type=int,
                        default=-1,
                        help="local_rank for distributed training on gpus")
    parser.add_argument('--seed',
                        type=int,
                        default=42,
                        help="random seed for initialization")
    parser.add_argument(
        '--gradient_accumulation_steps',
        type=int,
        default=1,
        help=
        "Number of updates steps to accumulate before performing a backward/update pass."
    )
    parser.add_argument(
        '--fp16',
        action='store_true',
        help="Whether to use 16-bit float precision instead of 32-bit")
    parser.add_argument(
        '--loss_scale',
        type=float,
        default=0,
        help=
        "Loss scaling to improve fp16 numeric stability. Only used when fp16 set to True.\n"
        "0 (default value): dynamic loss scaling.\n"
        "Positive power of 2: static loss scaling value.\n")
    parser.add_argument('--server_ip',
                        type=str,
                        default='',
                        help="Can be used for distant debugging.")
    parser.add_argument('--server_port',
                        type=str,
                        default='',
                        help="Can be used for distant debugging.")
    args = parser.parse_args()

    if args.server_ip and args.server_port:
        # Distant debugging - see https://code.visualstudio.com/docs/python/debugging#_attach-to-a-local-script
        import ptvsd
        print("Waiting for debugger attach")
        ptvsd.enable_attach(address=(args.server_ip, args.server_port),
                            redirect_output=True)
        ptvsd.wait_for_attach()

    processors = {
        "clinicalhedges": InputProcessor,
    }

    num_labels_task = {
        "clinicalhedges": [2, 2, 2, 2, 2],
    }

    if args.local_rank == -1 or args.no_cuda:
        device = torch.device("cuda" if torch.cuda.is_available()
                              and not args.no_cuda else "cpu")
        n_gpu = torch.cuda.device_count()
    else:
        torch.cuda.set_device(args.local_rank)
        device = torch.device("cuda", args.local_rank)
        n_gpu = 1
        # Initializes the distributed backend which will take care of synchronizing nodes/GPUs
        torch.distributed.init_process_group(backend='nccl')
    logger.info(
        "device: {} n_gpu: {}, distributed training: {}, 16-bits training: {}".
        format(device, n_gpu, bool(args.local_rank != -1), args.fp16))

    if args.gradient_accumulation_steps < 1:
        raise ValueError(
            "Invalid gradient_accumulation_steps parameter: {}, should be >= 1"
            .format(args.gradient_accumulation_steps))

    args.train_batch_size = args.train_batch_size // args.gradient_accumulation_steps

    random.seed(args.seed)
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)
    if n_gpu > 0:
        torch.cuda.manual_seed_all(args.seed)

    if not args.do_train and not args.do_eval:
        raise ValueError(
            "At least one of `do_train` or `do_eval` must be True.")

    task_name = args.task_name.lower()
    task_num = args.task_num
    if task_name not in processors:
        raise ValueError("Task not found: %s" % (task_name))

    processor = processors[task_name]()
    print(processor)
    num_labels = num_labels_task[task_name][task_num - 1]
    print(num_labels)
    label_list = processor.get_labels(task_num - 1)
    print(label_list)

    tokenizer = BertTokenizer.from_pretrained(args.bert_model,
                                              do_lower_case=args.do_lower_case)
    file = open(
        os.path.join(args.output_dir,
                     "Classification_Reports_Task_{}.txt".format(task_num)),
        'w')

    for part_index in range(args.num_parts_start, args.num_parts_end):
        train_examples = None
        num_train_optimization_steps = None
        if args.do_train:
            train_examples = processor.get_train_examples(
                args.data_dir, part_index, task_num)
            num_train_optimization_steps = int(
                len(train_examples) / args.train_batch_size /
                args.gradient_accumulation_steps) * args.num_train_epochs
            if args.local_rank != -1:
                num_train_optimization_steps = num_train_optimization_steps // torch.distributed.get_world_size(
                )

        # Prepare model
        cache_dir = args.cache_dir if args.cache_dir else os.path.join(
            str(PYTORCH_PRETRAINED_BERT_CACHE), 'distributed_{}'.format(
                args.local_rank))
        model = BertForSequenceClassification.from_pretrained(
            args.model_file, cache_dir=cache_dir, num_labels=num_labels)
        if args.fp16:
            model.half()
        model.to(device)
        if args.local_rank != -1:
            try:
                from apex.parallel import DistributedDataParallel as DDP
            except ImportError:
                raise ImportError(
                    "Please install apex from https://www.github.com/nvidia/apex to use distributed and fp16 training."
                )

            model = DDP(model)
        elif n_gpu > 1:
            model = torch.nn.DataParallel(model)

        # Prepare optimizer
        param_optimizer = list(model.named_parameters())
        no_decay = ['bias', 'LayerNorm.bias', 'LayerNorm.weight']
        optimizer_grouped_parameters = [{
            'params': [
                p for n, p in param_optimizer
                if not any(nd in n for nd in no_decay)
            ],
            'weight_decay':
            0.01
        }, {
            'params':
            [p for n, p in param_optimizer if any(nd in n for nd in no_decay)],
            'weight_decay':
            0.0
        }]
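        # The grouping above excludes biases and LayerNorm parameters (the
        # `no_decay` list) from L2 weight decay, the standard BERT fine-tuning setup.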
        if args.fp16:
            try:
                from apex.optimizers import FP16_Optimizer
                from apex.optimizers import FusedAdam
            except ImportError:
                raise ImportError(
                    "Please install apex from https://www.github.com/nvidia/apex to use distributed and fp16 training."
                )

            optimizer = FusedAdam(optimizer_grouped_parameters,
                                  lr=args.learning_rate,
                                  bias_correction=False,
                                  max_grad_norm=1.0)
            if args.loss_scale == 0:
                optimizer = FP16_Optimizer(optimizer, dynamic_loss_scale=True)
            else:
                optimizer = FP16_Optimizer(optimizer,
                                           static_loss_scale=args.loss_scale)

        else:
            optimizer = BertAdam(optimizer_grouped_parameters,
                                 lr=args.learning_rate,
                                 warmup=args.warmup_proportion,
                                 t_total=num_train_optimization_steps)

        global_step = 0
        nb_tr_steps = 0
        tr_loss = 0
        if args.do_train:
            train_features = convert_examples_to_features(
                train_examples, label_list, args.max_seq_length, tokenizer)
            logger.info(
                "***** Running training on Part {}  Task {}*****".format(
                    part_index, task_num))
            logger.info("  Num examples = %d", len(train_examples))
            logger.info("  Batch size = %d", args.train_batch_size)
            logger.info("  Num steps = %d", num_train_optimization_steps)
            all_input_ids = torch.tensor([f.input_ids for f in train_features],
                                         dtype=torch.long)
            all_input_mask = torch.tensor(
                [f.input_mask for f in train_features], dtype=torch.long)
            all_segment_ids = torch.tensor(
                [f.segment_ids for f in train_features], dtype=torch.long)
            all_label_ids = torch.tensor([f.label_id for f in train_features],
                                         dtype=torch.long)
            train_data = TensorDataset(all_input_ids, all_input_mask,
                                       all_segment_ids, all_label_ids)
            if args.local_rank == -1:
                train_sampler = RandomSampler(train_data)
            else:
                train_sampler = DistributedSampler(train_data)
            train_dataloader = DataLoader(train_data,
                                          sampler=train_sampler,
                                          batch_size=args.train_batch_size)

            for ep in trange(int(args.num_train_epochs), desc="Epoch"):
                model.train()
                tr_loss = 0
                nb_tr_examples, nb_tr_steps = 0, 0
                for step, batch in enumerate(train_dataloader):
                    batch = tuple(t.to(device) for t in batch)
                    input_ids, input_mask, segment_ids, label_ids = batch
                    loss = model(input_ids, segment_ids, input_mask, label_ids)
                    if n_gpu > 1:
                        loss = loss.mean()  # mean() to average on multi-gpu.
                    if args.gradient_accumulation_steps > 1:
                        loss = loss / args.gradient_accumulation_steps

                    if args.fp16:
                        optimizer.backward(loss)
                    else:
                        loss.backward()

                    tr_loss += loss.item()
                    nb_tr_examples += input_ids.size(0)
                    nb_tr_steps += 1
                    if (step + 1) % args.gradient_accumulation_steps == 0:
                        if args.fp16:
                            # modify learning rate with special warm up BERT uses
                            # if args.fp16 is False, BertAdam is used that handles this automatically
                            lr_this_step = args.learning_rate * warmup_linear(
                                global_step / num_train_optimization_steps,
                                args.warmup_proportion)
                            for param_group in optimizer.param_groups:
                                param_group['lr'] = lr_this_step
                        optimizer.step()
                        optimizer.zero_grad()
                        global_step += 1

                eval_examples = processor.get_dev_examples(
                    args.data_dir, part_index, task_num)
                eval_features = convert_examples_to_features(
                    eval_examples, label_list, args.max_seq_length, tokenizer)
                print("\n")
                print("Running evaluation for epoch: {}".format(ep))
                all_input_ids = torch.tensor(
                    [f.input_ids for f in eval_features], dtype=torch.long)
                all_input_mask = torch.tensor(
                    [f.input_mask for f in eval_features], dtype=torch.long)
                all_segment_ids = torch.tensor(
                    [f.segment_ids for f in eval_features], dtype=torch.long)
                all_label_ids = torch.tensor(
                    [f.label_id for f in eval_features], dtype=torch.long)
                eval_data = TensorDataset(all_input_ids, all_input_mask,
                                          all_segment_ids, all_label_ids)
                # Run prediction for full data
                eval_sampler = SequentialSampler(eval_data)
                eval_dataloader = DataLoader(eval_data,
                                             sampler=eval_sampler,
                                             batch_size=args.eval_batch_size)

                model.eval()
                eval_loss, eval_accuracy = 0, 0
                nb_eval_steps, nb_eval_examples = 0, 0

                for input_ids, input_mask, segment_ids, label_ids in eval_dataloader:
                    input_ids = input_ids.to(device)
                    input_mask = input_mask.to(device)
                    segment_ids = segment_ids.to(device)
                    label_ids = label_ids.to(device)

                    with torch.no_grad():
                        tmp_eval_loss = model(input_ids, segment_ids,
                                              input_mask, label_ids)
                        logits = model(input_ids, segment_ids, input_mask)

                    logits = logits.detach().cpu().numpy()
                    label_ids = label_ids.to('cpu').numpy()
                    tmp_eval_accuracy = accuracy(logits, label_ids)

                    eval_loss += tmp_eval_loss.mean().item()
                    eval_accuracy += tmp_eval_accuracy

                    nb_eval_examples += input_ids.size(0)
                    nb_eval_steps += 1

                eval_loss = eval_loss / nb_eval_steps
                eval_accuracy = eval_accuracy / nb_eval_examples
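                # eval_loss is averaged per batch (steps) while eval_accuracy is
                # averaged per example, since accuracy() presumably returns the
                # number of correct predictions in each batch.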
                loss = tr_loss / nb_tr_steps if args.do_train else None
                result = {
                    'eval_loss': eval_loss,
                    'eval_accuracy': eval_accuracy,
                    'global_step': global_step,
                    'loss': loss
                }

                for key in sorted(result.keys()):
                    print(key, str(result[key]))
                print()

        if args.do_train:
            # Save a trained model and the associated configuration
            model_to_save = model.module if hasattr(
                model, 'module') else model  # Only save the model itself
            if (os.path.exists(
                    os.path.join(
                        args.output_dir,
                        "Model_Part_{}_Task_{}".format(part_index,
                                                       task_num)))):
                shutil.rmtree(
                    os.path.join(
                        args.output_dir,
                        "Model_Part_{}_Task_{}".format(part_index, task_num)))
            os.mkdir(
                os.path.join(
                    args.output_dir,
                    "Model_Part_{}_Task_{}".format(part_index, task_num)))
            output_model_file = os.path.join(
                args.output_dir,
                "Model_Part_{}_Task_{}".format(part_index,
                                               task_num), WEIGHTS_NAME)
            torch.save(model_to_save.state_dict(), output_model_file)
            output_config_file = os.path.join(
                args.output_dir,
                "Model_Part_{}_Task_{}".format(part_index,
                                               task_num), CONFIG_NAME)
            with open(output_config_file, 'w') as f:
                f.write(model_to_save.config.to_json_string())
        if args.do_eval:
            # Load a trained model and config that you have fine-tuned
            output_model_file = os.path.join(
                args.output_dir,
                "Model_Part_{}_Task_{}".format(part_index,
                                               task_num), WEIGHTS_NAME)
            output_config_file = os.path.join(
                args.output_dir,
                "Model_Part_{}_Task_{}".format(part_index,
                                               task_num), CONFIG_NAME)
            config = BertConfig(output_config_file)
            model = BertForSequenceClassification(config,
                                                  num_labels=num_labels)
            model.load_state_dict(
                torch.load(output_model_file, map_location='cpu'))
        model.to(device)

        if args.do_eval and (args.local_rank == -1
                             or torch.distributed.get_rank() == 0):
            eval_examples = processor.get_test_examples(
                args.data_dir, part_index, task_num)
            eval_features = convert_examples_to_features(
                eval_examples, label_list, args.max_seq_length, tokenizer)
            complete_user_ids = list()
            for example in eval_examples:
                complete_user_ids.append(example.guid)
            logger.info("***** Running Test for Part {} Task {}*****".format(
                part_index, task_num))
            logger.info("  Num examples = %d", len(eval_examples))
            logger.info("  Batch size = %d", args.eval_batch_size)
            all_input_ids = torch.tensor([f.input_ids for f in eval_features],
                                         dtype=torch.long)
            all_input_mask = torch.tensor(
                [f.input_mask for f in eval_features], dtype=torch.long)
            all_segment_ids = torch.tensor(
                [f.segment_ids for f in eval_features], dtype=torch.long)
            all_label_ids = torch.tensor([f.label_id for f in eval_features],
                                         dtype=torch.long)
            eval_data = TensorDataset(all_input_ids, all_input_mask,
                                      all_segment_ids, all_label_ids)
            # Run prediction for full data
            eval_sampler = SequentialSampler(eval_data)
            eval_dataloader = DataLoader(eval_data,
                                         sampler=eval_sampler,
                                         batch_size=args.eval_batch_size)

            model.eval()
            eval_loss, eval_accuracy = 0, 0
            nb_eval_steps, nb_eval_examples = 0, 0
            complete_label_ids = list()
            complete_outputs = list()
            complete_probs = list()
            for input_ids, input_mask, segment_ids, label_ids in tqdm(
                    eval_dataloader, desc="Evaluating"):
                input_ids = input_ids.to(device)
                input_mask = input_mask.to(device)
                segment_ids = segment_ids.to(device)
                label_ids = label_ids.to(device)
                with torch.no_grad():
                    tmp_eval_loss = model(input_ids, segment_ids, input_mask,
                                          label_ids)
                    logits = model(input_ids, segment_ids, input_mask)
                last_layer_op = copy.deepcopy(logits)
                logits = logits.detach().cpu().numpy()
                sm = torch.nn.Softmax(dim=1)  # explicit dim avoids the implicit-dim deprecation warning
                probabilities = sm(last_layer_op)
                probabilities = probabilities.detach().cpu().numpy()
                label_ids = label_ids.to('cpu').numpy()
                tmp_eval_accuracy = accuracy(logits, label_ids)
                outputs = np.argmax(logits, axis=1)
                complete_outputs.extend(outputs)
                complete_label_ids.extend(label_ids)
                complete_probs.extend(probabilities[:, 1])

                eval_loss += tmp_eval_loss.mean().item()
                eval_accuracy += tmp_eval_accuracy

                nb_eval_examples += input_ids.size(0)
                nb_eval_steps += 1

            outcsv = open(os.path.join(
                args.output_dir,
                "Reqd_Labels_Part_{}_Task_{}.csv".format(part_index,
                                                         task_num)),
                          'w',
                          encoding='utf8',
                          newline='')
            writer = csv.writer(outcsv, quotechar='"')
            writer.writerow(["ID", "True", "Pred", "Prob"])
            for user, true, pred, prob in zip(complete_user_ids,
                                              complete_label_ids,
                                              complete_outputs,
                                              complete_probs):
                writer.writerow([user, true, pred, prob])
            outcsv.close()
            eval_loss = eval_loss / nb_eval_steps
            eval_accuracy = eval_accuracy / nb_eval_examples
            loss = tr_loss / nb_tr_steps if args.do_train else None
            result = {
                'eval_loss': eval_loss,
                'eval_accuracy': eval_accuracy,
                'global_step': global_step,
                'loss': loss
            }

            output_eval_file = os.path.join(args.output_dir,
                                            "eval_results.txt")
            with open(output_eval_file, "w") as writer:
                logger.info("***** Eval results *****")
                for key in sorted(result.keys()):
                    logger.info("  %s = %s", key, str(result[key]))
                    writer.write("%s = %s\n" % (key, str(result[key])))

            file.write(
                "\nClassification Report Part- {}\n\n".format(part_index) +
                classification_report(complete_label_ids, complete_outputs) +
                "\n\n\n")
    file.close()
Example #19
def main():
    global best_prec1, args

    args.distributed = args.world_size > 1
    args.gpu = 0
    if args.distributed:
        args.gpu = args.rank % torch.cuda.device_count()


    if args.distributed:
        torch.cuda.set_device(args.gpu)
        dist.init_process_group(backend=args.dist_backend, init_method=args.dist_url,
                                world_size=args.world_size, rank=args.rank)

    if args.fp16:
        assert torch.backends.cudnn.enabled, "fp16 mode requires cudnn backend to be enabled."

    # create model
    if args.pretrained:
        print("=> using pre-trained model '{}'".format(args.arch))
        model = models.__dict__[args.arch](pretrained=True, num_classes=args.num_classes)
    else:
        print("=> creating model '{}'".format(args.arch))
        model = models.__dict__[args.arch](num_classes=args.num_classes)

    model = model.cuda()
    if args.fp16:
        model = network_to_half(model)
    if args.distributed:
        model = DDP(model)

    global model_params, master_params
    if args.fp16:
        model_params, master_params = prep_param_lists(model)
    else:
        master_params = list(model.parameters())

    # define loss function (criterion) and optimizer
    criterion = nn.CrossEntropyLoss().cuda()

    optimizer = torch.optim.SGD(master_params, args.lr,
                                momentum=args.momentum,
                                weight_decay=args.weight_decay)
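    # In fp16 mode the model weights are half precision (network_to_half) while
    # the optimizer updates a separate fp32 master copy (prep_param_lists), the
    # classic pre-Amp mixed-precision recipe from apex.fp16_utils.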

    # optionally resume from a checkpoint
    if args.resume:
        if os.path.isfile(args.resume):
            print("=> loading checkpoint '{}'".format(args.resume))
            checkpoint = torch.load(args.resume, map_location = lambda storage, loc: storage.cuda(args.gpu))
            args.start_epoch = checkpoint['epoch']
            best_prec1 = checkpoint['best_prec1']
            model.load_state_dict(checkpoint['state_dict'])
            optimizer.load_state_dict(checkpoint['optimizer'])
            print("=> loaded checkpoint '{}' (epoch {})"
                  .format(args.resume, checkpoint['epoch']))
        else:
            print("=> no checkpoint found at '{}'".format(args.resume))

    # Data loading code
    traindir = os.path.join(args.data, 'train')
    valdir = os.path.join(args.data, 'val')

    pipe = HybridPipe(batch_size=args.batch_size, num_threads=args.workers, device_id = args.rank, data_dir = traindir)
    pipe.build()
    test_run = pipe.run()
    from nvidia.dali.plugin.pytorch import DALIClassificationIterator
    train_loader = DALIClassificationIterator(pipe, size = int(1281167 / args.world_size) )


    pipe = HybridPipe(batch_size=args.batch_size, num_threads=args.workers, device_id = args.rank, data_dir = valdir)
    pipe.build()
    test_run = pipe.run()
    val_loader = DALIClassificationIterator(pipe, size = int(50000 / args.world_size) )
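    # 1281167 and 50000 are the ImageNet-1k train/val set sizes; each rank
    # iterates over its 1/world_size shard of the data.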

    if args.evaluate:
        validate(val_loader, model, criterion)
        return

    for epoch in range(args.start_epoch, args.epochs):
        adjust_learning_rate(optimizer, epoch)

        # train for one epoch
        train(train_loader, model, criterion, optimizer, epoch)
        if args.prof:
            break
        # evaluate on validation set
        prec1 = validate(val_loader, model, criterion)

        # remember best prec@1 and save checkpoint
        if args.rank == 0:
            is_best = prec1 > best_prec1
            best_prec1 = max(prec1, best_prec1)
            save_checkpoint({
                'epoch': epoch + 1,
                'arch': args.arch,
                'state_dict': model.state_dict(),
                'best_prec1': best_prec1,
                'optimizer' : optimizer.state_dict(),
            }, is_best)
def main(pargs):

    #init distributed training
    comm.init(pargs.wireup_method)
    comm_rank = comm.get_rank()
    comm_local_rank = comm.get_local_rank()
    comm_size = comm.get_size()

    #set seed
    seed = 333

    # Some setup
    torch.manual_seed(seed)
    if torch.cuda.is_available():
        printr("Using GPUs", 0)
        device = torch.device("cuda", comm_local_rank)
        torch.cuda.manual_seed(seed)
        #necessary for AMP to work
        torch.cuda.set_device(device)
    else:
        printr("Using CPUs", 0)
        device = torch.device("cpu")

    #visualize?
    visualize = (pargs.training_visualization_frequency >
                 0) or (pargs.validation_visualization_frequency > 0)

    #set up directories
    root_dir = os.path.join(pargs.data_dir_prefix)
    output_dir = pargs.output_dir
    plot_dir = os.path.join(output_dir, "plots")
    if comm_rank == 0:
        if not os.path.isdir(output_dir):
            os.makedirs(output_dir)
        if visualize and not os.path.isdir(plot_dir):
            os.makedirs(plot_dir)

    # Setup WandB
    if (pargs.logging_frequency > 0) and (comm_rank == 0):
        # get wandb api token
        with open(os.path.join(pargs.wandb_certdir, ".wandbirc")) as f:
            token = f.readlines()[0].replace("\n", "").split()
            wblogin = token[0]
            wbtoken = token[1]
        # log in: that call can be blocking, it should be quick
        sp.call(["wandb", "login", wbtoken])

        #init db and get config
        resume_flag = pargs.run_tag if pargs.resume_logging else False
        wandb.init(entity=wblogin,
                   project='deepcam',
                   name=pargs.run_tag,
                   id=pargs.run_tag,
                   resume=resume_flag)
        config = wandb.config

        #set general parameters
        config.root_dir = root_dir
        config.output_dir = pargs.output_dir
        config.max_epochs = pargs.max_epochs
        config.local_batch_size = pargs.local_batch_size
        config.num_workers = comm_size
        config.channels = pargs.channels
        config.optimizer = pargs.optimizer
        config.start_lr = pargs.start_lr
        config.adam_eps = pargs.adam_eps
        config.weight_decay = pargs.weight_decay
        config.model_prefix = pargs.model_prefix
        config.amp_opt_level = pargs.amp_opt_level
        config.loss_weight_pow = pargs.loss_weight_pow
        config.lr_warmup_steps = pargs.lr_warmup_steps
        config.lr_warmup_factor = pargs.lr_warmup_factor

        # lr schedule if applicable
        if pargs.lr_schedule:
            for key in pargs.lr_schedule:
                config.update({"lr_schedule_" + key: pargs.lr_schedule[key]},
                              allow_val_change=True)

    # Define architecture
    n_input_channels = len(pargs.channels)
    n_output_channels = 3
    net = deeplab_xception.DeepLabv3_plus(n_input=n_input_channels,
                                          n_classes=n_output_channels,
                                          os=16,
                                          pretrained=False,
                                          rank=comm_rank)
    net.to(device)

    #select loss
    loss_pow = pargs.loss_weight_pow
    #some magic numbers
    class_weights = [
        0.986267818390377**loss_pow, 0.0004578708870701058**loss_pow,
        0.01327431072255291**loss_pow
    ]
    fpw_1 = 2.61461122397522257612
    fpw_2 = 1.71641974795896018744
    criterion = losses.fp_loss
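    # The class weights are (presumably) per-class frequencies raised to
    # loss_weight_pow, and fpw_1/fpw_2 are additional false-positive weights
    # consumed by losses.fp_loss in the training loop below.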

    #select optimizer
    optimizer = None
    if pargs.optimizer == "Adam":
        optimizer = optim.Adam(net.parameters(),
                               lr=pargs.start_lr,
                               eps=pargs.adam_eps,
                               weight_decay=pargs.weight_decay)
    elif pargs.optimizer == "AdamW":
        optimizer = optim.AdamW(net.parameters(),
                                lr=pargs.start_lr,
                                eps=pargs.adam_eps,
                                weight_decay=pargs.weight_decay)
    elif have_apex and (pargs.optimizer == "LAMB"):
        optimizer = aoptim.FusedLAMB(net.parameters(),
                                     lr=pargs.start_lr,
                                     eps=pargs.adam_eps,
                                     weight_decay=pargs.weight_decay)
    else:
        raise NotImplementedError("Error, optimizer {} not supported".format(
            pargs.optimizer))

    if have_apex:
        #wrap model and opt into amp
        net, optimizer = amp.initialize(net,
                                        optimizer,
                                        opt_level=pargs.amp_opt_level)

    #make model distributed
    net = DDP(net)

    #restart from checkpoint if desired
    #if (comm_rank == 0) and (pargs.checkpoint):
    #load it on all ranks for now
    if pargs.checkpoint:
        checkpoint = torch.load(pargs.checkpoint, map_location=device)
        start_step = checkpoint['step']
        start_epoch = checkpoint['epoch']
        optimizer.load_state_dict(checkpoint['optimizer'])
        net.load_state_dict(checkpoint['model'])
        if have_apex:
            amp.load_state_dict(checkpoint['amp'])
    else:
        start_step = 0
        start_epoch = 0

    #select scheduler
    if pargs.lr_schedule:
        scheduler_after = ph.get_lr_schedule(pargs.start_lr,
                                             pargs.lr_schedule,
                                             optimizer,
                                             last_step=start_step)

        if pargs.lr_warmup_steps > 0:
            scheduler = GradualWarmupScheduler(
                optimizer,
                multiplier=pargs.lr_warmup_factor,
                total_epoch=pargs.lr_warmup_steps,
                after_scheduler=scheduler_after)
        else:
            scheduler = scheduler_after

    #broadcast the restart step and epoch so all ranks agree (the model and
    #optimizer state are already loaded from the checkpoint on every rank above)
    steptens = torch.tensor(np.array([start_step, start_epoch]),
                            requires_grad=False).to(device)
    dist.broadcast(steptens, src=0)

    ##broadcast model and optimizer state
    #hvd.broadcast_parameters(net.state_dict(), root_rank = 0)
    #hvd.broadcast_optimizer_state(optimizer, root_rank = 0)

    #unpack the bcasted tensor
    start_step = steptens.cpu().numpy()[0]
    start_epoch = steptens.cpu().numpy()[1]

    # Set up the data feeder
    # train
    train_dir = os.path.join(root_dir, "train")
    train_set = cam.CamDataset(train_dir,
                               statsfile=os.path.join(root_dir, 'stats.h5'),
                               channels=pargs.channels,
                               shuffle=True,
                               preprocess=True,
                               comm_size=comm_size,
                               comm_rank=comm_rank)
    train_loader = DataLoader(
        train_set,
        pargs.local_batch_size,
        num_workers=min([pargs.max_inter_threads, pargs.local_batch_size]),
        drop_last=True)

    # validation: we only want to shuffle the set if we are cutting off validation after a certain number of steps
    validation_dir = os.path.join(root_dir, "validation")
    validation_set = cam.CamDataset(validation_dir,
                                    statsfile=os.path.join(
                                        root_dir, 'stats.h5'),
                                    channels=pargs.channels,
                                    shuffle=(pargs.max_validation_steps
                                             is not None),
                                    preprocess=True,
                                    comm_size=comm_size,
                                    comm_rank=comm_rank)
    validation_loader = DataLoader(
        validation_set,
        pargs.local_batch_size,
        num_workers=min([pargs.max_inter_threads, pargs.local_batch_size]),
        drop_last=True)

    #for visualization
    if visualize:
        viz = vizc.CamVisualizer()

    # Train network
    if (pargs.logging_frequency > 0) and (comm_rank == 0):
        wandb.watch(net)

    printr(
        '{:14.4f} REPORT: starting training'.format(
            dt.datetime.now().timestamp()), 0)
    step = start_step
    epoch = start_epoch
    current_lr = pargs.start_lr if not pargs.lr_schedule else scheduler.get_last_lr()[0]
    net.train()
    while True:

        printr(
            '{:14.4f} REPORT: starting epoch {}'.format(
                dt.datetime.now().timestamp(), epoch), 0)

        #for inputs_raw, labels, source in train_loader:
        for inputs, label, filename in train_loader:

            #send to device
            inputs = inputs.to(device)
            label = label.to(device)

            # forward pass
            outputs = net.forward(inputs)

            # Compute loss and average across nodes
            loss = criterion(outputs,
                             label,
                             weight=class_weights,
                             fpw_1=fpw_1,
                             fpw_2=fpw_2)

            # allreduce for loss
            loss_avg = loss.detach()
            dist.reduce(loss_avg, dst=0, op=dist.ReduceOp.SUM)
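            # the summed loss lands on rank 0 only; it is divided by comm_size
            # when reported and logged below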

            # Compute score
            predictions = torch.max(outputs, 1)[1]
            iou = utils.compute_score(predictions,
                                      label,
                                      device_id=device,
                                      num_classes=3)
            iou_avg = iou.detach()
            dist.reduce(iou_avg, dst=0, op=dist.ReduceOp.SUM)

            # Backprop
            optimizer.zero_grad()
            if have_apex:
                with amp.scale_loss(loss, optimizer) as scaled_loss:
                    scaled_loss.backward()
            else:
                loss.backward()
            optimizer.step()

            #step counter
            step += 1

            if pargs.lr_schedule:
                current_lr = scheduler.get_last_lr()[0]
                scheduler.step()

            #print some metrics
            printr(
                '{:14.4f} REPORT training: step {} loss {} iou {} LR {}'.
                format(dt.datetime.now().timestamp(), step,
                       loss_avg.item() / float(comm_size),
                       iou_avg.item() / float(comm_size), current_lr), 0)

            #visualize if requested
            if (step % pargs.training_visualization_frequency
                    == 0) and (comm_rank == 0):
                #extract sample id and data tensors
                sample_idx = np.random.randint(low=0, high=label.shape[0])
                plot_input = inputs.detach()[sample_idx, 0, ...].cpu().numpy()
                plot_prediction = predictions.detach()[sample_idx,
                                                       ...].cpu().numpy()
                plot_label = label.detach()[sample_idx, ...].cpu().numpy()

                #create filenames
                outputfile = os.path.basename(filename[sample_idx]).replace(
                    "data-", "training-").replace(".h5", ".png")
                outputfile = os.path.join(plot_dir, outputfile)

                #plot
                viz.plot(filename[sample_idx], outputfile, plot_input,
                         plot_prediction, plot_label)

                #log if requested
                if pargs.logging_frequency > 0:
                    img = Image.open(outputfile)
                    wandb.log(
                        {
                            "Training Examples": [
                                wandb.Image(
                                    img, caption="Prediction vs. Ground Truth")
                            ]
                        },
                        step=step)

            #log if requested
            if (pargs.logging_frequency > 0) and (
                    step % pargs.logging_frequency == 0) and (comm_rank == 0):
                wandb.log(
                    {"Training Loss": loss_avg.item() / float(comm_size)},
                    step=step)
                wandb.log({"Training IoU": iou_avg.item() / float(comm_size)},
                          step=step)
                wandb.log({"Current Learning Rate": current_lr}, step=step)

            # validation step if desired
            if (step % pargs.validation_frequency == 0):

                #eval
                net.eval()

                count_sum_val = torch.Tensor([0.]).to(device)
                loss_sum_val = torch.Tensor([0.]).to(device)
                iou_sum_val = torch.Tensor([0.]).to(device)

                # disable gradients
                with torch.no_grad():

                    # iterate over validation sample
                    step_val = 0
                    # only print once per eval at most
                    visualized = False
                    for inputs_val, label_val, filename_val in validation_loader:

                        #send to device
                        inputs_val = inputs_val.to(device)
                        label_val = label_val.to(device)

                        # forward pass
                        outputs_val = net.forward(inputs_val)

                        # Compute loss and average across nodes
                        # note: unlike the training call, fpw_1/fpw_2 are not
                        # passed here, so fp_loss presumably falls back to its
                        # default false-positive weights for validation
                        loss_val = criterion(outputs_val,
                                             label_val,
                                             weight=class_weights)
                        loss_sum_val += loss_val

                        #increase counter
                        count_sum_val += 1.

                        # Compute score
                        predictions_val = torch.max(outputs_val, 1)[1]
                        iou_val = utils.compute_score(predictions_val,
                                                      label_val,
                                                      device_id=device,
                                                      num_classes=3)
                        iou_sum_val += iou_val

                        # Visualize
                        if (step_val % pargs.validation_visualization_frequency
                                == 0) and (not visualized) and (comm_rank
                                                                == 0):
                            #extract sample id and data tensors
                            sample_idx = np.random.randint(
                                low=0, high=label_val.shape[0])
                            plot_input = inputs_val.detach()[
                                sample_idx, 0, ...].cpu().numpy()
                            plot_prediction = predictions_val.detach()[
                                sample_idx, ...].cpu().numpy()
                            plot_label = label_val.detach()[sample_idx,
                                                            ...].cpu().numpy()

                            #create filenames
                            outputfile = os.path.basename(
                                filename_val[sample_idx]).replace(
                                    "data-",
                                    "validation-").replace(".h5", ".png")
                            outputfile = os.path.join(plot_dir, outputfile)

                            #plot
                            viz.plot(filename_val[sample_idx], outputfile,
                                     plot_input, plot_prediction, plot_label)
                            visualized = True

                            #log if requested
                            if pargs.logging_frequency > 0:
                                img = Image.open(outputfile)
                                wandb.log(
                                    {
                                        "Validation Examples": [
                                            wandb.Image(
                                                img,
                                                caption=
                                                "Prediction vs. Ground Truth")
                                        ]
                                    },
                                    step=step)

                        #increase eval step counter
                        step_val += 1

                        if (pargs.max_validation_steps is not None
                            ) and step_val > pargs.max_validation_steps:
                            break

                # average the validation loss
                dist.reduce(count_sum_val, dst=0, op=dist.ReduceOp.SUM)
                dist.reduce(loss_sum_val, dst=0, op=dist.ReduceOp.SUM)
                dist.reduce(iou_sum_val, dst=0, op=dist.ReduceOp.SUM)
                loss_avg_val = loss_sum_val.item() / count_sum_val.item()
                iou_avg_val = iou_sum_val.item() / count_sum_val.item()

                # print results
                printr(
                    '{:14.4f} REPORT validation: step {} loss {} iou {}'.
                    format(dt.datetime.now().timestamp(), step, loss_avg_val,
                           iou_avg_val), 0)

                # log in wandb
                if (pargs.logging_frequency > 0) and (comm_rank == 0):
                    wandb.log({"Validation Loss": loss_avg_val}, step=step)
                    wandb.log({"Validation IoU": iou_avg_val}, step=step)

                # set to train
                net.train()

            #save model if desired
            if (step % pargs.save_frequency == 0) and (comm_rank == 0):
                checkpoint = {
                    'step': step,
                    'epoch': epoch,
                    'model': net.state_dict(),
                    'optimizer': optimizer.state_dict()
                }
                if have_apex:
                    checkpoint['amp'] = amp.state_dict()
                torch.save(
                    checkpoint,
                    os.path.join(
                        output_dir,
                        pargs.model_prefix + "_step_" + str(step) + ".cpt"))

        #do some after-epoch prep, just for the books
        epoch += 1
        if comm_rank == 0:

            # Save the model
            checkpoint = {
                'step': step,
                'epoch': epoch,
                'model': net.state_dict(),
                'optimizer': optimizer.state_dict()
            }
            if have_apex:
                checkpoint['amp'] = amp.state_dict()
            torch.save(
                checkpoint,
                os.path.join(
                    output_dir,
                    pargs.model_prefix + "_epoch_" + str(epoch) + ".cpt"))

        #are we done?
        if epoch >= pargs.max_epochs:
            break

    printr(
        '{:14.4f} REPORT: finishing training'.format(
            dt.datetime.now().timestamp()), 0)
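
# A hypothetical sketch of the printr helper used throughout this example (its
# real definition is not in this excerpt); the assumption is that it emits a
# message only on the requested rank so multi-rank runs do not duplicate logs.
# def printr(msg, rank=0, comm_rank=0):
#     if comm_rank == rank:
#         print(msg)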
Example #21
def main():
    parser = argparse.ArgumentParser()

    ## Required parameters
    parser.add_argument("--data_dir",
                        default='data/conll2003/',
                        type=str,
                        required=True,
                        help="The input data dir. Should contain the .tsv files (or other data files) for the task.")
    parser.add_argument("--bert_model", default=None, type=str, required=True,
                        help="Bert pre-trained model selected in the list: bert-base-uncased, "
                        "bert-large-uncased, bert-base-cased, bert-large-cased, bert-base-multilingual-uncased, "
                        "bert-base-multilingual-cased, bert-base-chinese.")
    parser.add_argument("--task_name",
                        default='NER',
                        type=str,
                        required=True,
                        help="The name of the task to train.")
    parser.add_argument("--output_dir",
                        default='ner_output',
                        type=str,
                        required=True,
                        help="The output directory where the model predictions and checkpoints will be written.")

    ## Other parameters
    parser.add_argument("--cache_dir",
                        default="",
                        type=str,
                        help="Where do you want to store the pre-trained models downloaded from s3")
    parser.add_argument("--max_seq_length",
                        default=128,
                        type=int,
                        help="The maximum total input sequence length after WordPiece tokenization. \n"
                             "Sequences longer than this will be truncated, and sequences shorter \n"
                             "than this will be padded.")
    parser.add_argument("--do_train",
                        action='store_true',
                        help="Whether to run training.")
    parser.add_argument("--do_eval",
                        action='store_true',
                        help="Whether to run eval on the dev set.")
    parser.add_argument("--do_test",
                        action='store_true',
                        help="Whether to run test on the test set.")
    parser.add_argument("--do_pred",
                        action='store_true',
                        help="Whether to run pred on the pred set.")
    parser.add_argument("--do_lower_case",
                        action='store_true',
                        help="Set this flag if you are using an uncased model.")
    parser.add_argument("--train_batch_size",
                        default=32,
                        type=int,
                        help="Total batch size for training.")
    parser.add_argument("--eval_batch_size",
                        default=8,
                        type=int,
                        help="Total batch size for eval.")
    parser.add_argument("--learning_rate",
                        default=5e-5,
                        type=float,
                        help="The initial learning rate for Adam.")
    parser.add_argument("--num_train_epochs",
                        default=4.0,  # originally 3.0
                        type=float,
                        help="Total number of training epochs to perform.")
    parser.add_argument("--warmup_proportion",
                        default=0.1,
                        type=float,
                        help="Proportion of training to perform linear learning rate warmup for. "
                             "E.g., 0.1 = 10%% of training.")
    parser.add_argument("--no_cuda",
                        action='store_true',
                        help="Whether not to use CUDA when available")
    parser.add_argument("--local_rank",
                        type=int,
                        default=-1,
                        help="local_rank for distributed training on gpus")
    parser.add_argument('--seed',
                        type=int,
                        default=42,
                        help="random seed for initialization")
    parser.add_argument('--clip',
                        type=float,
                        default=0.5,
                        help="gradient clipping")
    parser.add_argument('--gradient_accumulation_steps',
                        type=int,
                        default=1,
                        help="Number of updates steps to accumulate before performing a backward/update pass.")
    parser.add_argument('--fp16',
                        action='store_true',
                        help="Whether to use 16-bit float precision instead of 32-bit")
    parser.add_argument('--loss_scale',
                        type=float, default=0,
                        help="Loss scaling to improve fp16 numeric stability. Only used when fp16 set to True.\n"
                             "0 (default value): dynamic loss scaling.\n"
                             "Positive power of 2: static loss scaling value.\n")
    parser.add_argument('--server_ip', type=str, default='', help="Can be used for distant debugging.")
    parser.add_argument('--server_port', type=str, default='', help="Can be used for distant debugging.")
    
    parser.add_argument('--text_a', type=str, default='', help="input text_a.")
    parser.add_argument('--text_b', type=str, default='', help="input text_b.")
    
    args = parser.parse_args()

    if args.server_ip and args.server_port:
        # Distant debugging - see https://code.visualstudio.com/docs/python/debugging#_attach-to-a-local-script
        import ptvsd
        print("Waiting for debugger attach")
        ptvsd.enable_attach(address=(args.server_ip, args.server_port), redirect_output=True)
        ptvsd.wait_for_attach()

    processors = {
        "ner": NerProcessor
    }

    if args.local_rank == -1 or args.no_cuda:
        device = torch.device("cuda" if torch.cuda.is_available() and not args.no_cuda else "cpu")
        n_gpu = torch.cuda.device_count()
    else:
        torch.cuda.set_device(args.local_rank)
        device = torch.device("cuda", args.local_rank)
        n_gpu = 1
        # Initializes the distributed backend which will take care of synchronizing nodes/GPUs
        torch.distributed.init_process_group(backend='nccl')

    if args.gradient_accumulation_steps < 1:
        raise ValueError("Invalid gradient_accumulation_steps parameter: {}, should be >= 1".format(
                            args.gradient_accumulation_steps))

    args.train_batch_size = args.train_batch_size // args.gradient_accumulation_steps

    random.seed(args.seed)
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)
    if n_gpu > 0:
        torch.cuda.manual_seed_all(args.seed)

    if not args.do_train and not args.do_eval and not args.do_pred:
        raise ValueError("At least one of `do_train` or `do_eval` must be True.")

    if os.path.exists(args.output_dir) and os.listdir(args.output_dir) and args.do_train:
        raise ValueError("Output directory ({}) already exists and is not empty.".format(args.output_dir))
    if not os.path.exists(args.output_dir):
        os.makedirs(args.output_dir)

    task_name = args.task_name.lower()

    if task_name not in processors:
        raise ValueError("Task not found: %s" % (task_name))

    processor = processors[task_name]()
    label_list = processor.get_labels(args.data_dir)
    num_labels_task = {"ner": len(label_list)}
    num_labels = num_labels_task[task_name]

    tokenizer = BertTokenizer.from_pretrained(args.bert_model, do_lower_case=args.do_lower_case)

    train_examples = None
    num_train_optimization_steps = None
    if args.do_train:
        train_examples = processor.get_train_examples(args.data_dir)
        #train_examples = train_examples[:1000]
        print("train_examples :: ",len(list(train_examples)))
        num_train_optimization_steps = int(
            len(train_examples) / args.train_batch_size / args.gradient_accumulation_steps) * args.num_train_epochs
        if args.local_rank != -1:
            num_train_optimization_steps = num_train_optimization_steps // torch.distributed.get_world_size()

    # Prepare model
    cache_dir = args.cache_dir if args.cache_dir else os.path.join(PYTORCH_PRETRAINED_BERT_CACHE, 'distributed_{}'.format(args.local_rank))
    model = BertForTokenClassification.from_pretrained(args.bert_model, cache_dir=cache_dir, num_labels=num_labels)

    if args.fp16:
        model.half()
    model.to(device)
    if args.local_rank != -1:
        try:
            from apex.parallel import DistributedDataParallel as DDP
        except ImportError:
            raise ImportError("Please install apex from https://www.github.com/nvidia/apex to use distributed and fp16 training.")

        model = DDP(model)
    elif n_gpu > 1:
        model = torch.nn.DataParallel(model)

    # Prepare optimizer
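    # Common BERT fine-tuning practice: apply weight decay to all parameters except biases and LayerNorm weights.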
    param_optimizer = list(model.named_parameters())
    no_decay = ['bias', 'LayerNorm.bias', 'LayerNorm.weight']
    optimizer_grouped_parameters = [
        {'params': [p for n, p in param_optimizer if not any(nd in n for nd in no_decay)], 'weight_decay': 0.01},
        {'params': [p for n, p in param_optimizer if any(nd in n for nd in no_decay)], 'weight_decay': 0.0}
        ]
    if args.fp16:
        try:
            from apex.optimizers import FP16_Optimizer
            from apex.optimizers import FusedAdam
        except ImportError:
            raise ImportError("Please install apex from https://www.github.com/nvidia/apex to use distributed and fp16 training.")

        optimizer = FusedAdam(optimizer_grouped_parameters,
                              lr=args.learning_rate,
                              bias_correction=False,
                              max_grad_norm=1.0)
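        # loss_scale == 0 selects dynamic loss scaling; a positive value is used as a static scale.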
        if args.loss_scale == 0:
            optimizer = FP16_Optimizer(optimizer, dynamic_loss_scale=True)
        else:
            optimizer = FP16_Optimizer(optimizer, static_loss_scale=args.loss_scale)

    else:
        optimizer = BertAdam(optimizer_grouped_parameters,
                             lr=args.learning_rate,
                             warmup=args.warmup_proportion,
                             t_total=num_train_optimization_steps)

    global_step = 0
    nb_tr_steps = 0
    tr_loss = 0
    if args.do_train:
        train_features = convert_examples_to_features(
            train_examples, label_list, args.max_seq_length, tokenizer)

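        # Pack the features into tensors and wrap them in a TensorDataset for the DataLoader.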
        all_input_ids = torch.tensor([f.input_ids for f in train_features], dtype=torch.long)
        all_input_mask = torch.tensor([f.input_mask for f in train_features], dtype=torch.long)
        all_segment_ids = torch.tensor([f.segment_ids for f in train_features], dtype=torch.long)
        all_label_ids = torch.tensor([f.label_id for f in train_features], dtype=torch.long)
        train_data = TensorDataset(all_input_ids, all_input_mask, all_segment_ids, all_label_ids)
        if args.local_rank == -1:
            train_sampler = RandomSampler(train_data)
        else:
            train_sampler = DistributedSampler(train_data)
        train_dataloader = DataLoader(train_data, sampler=train_sampler, batch_size=args.train_batch_size)

        model.train()
        for _ in trange(int(args.num_train_epochs), desc="Epoch"):
            tr_loss = 0
            nb_tr_examples, nb_tr_steps = 0, 0
            for step, batch in enumerate(tqdm(train_dataloader, desc="Iteration")):
                batch = tuple(t.to(device) for t in batch)
                input_ids, input_mask, segment_ids, label_ids = batch
                
                loss = model(input_ids, segment_ids, input_mask, label_ids)
                if n_gpu > 1:
                    loss = loss.mean() # mean() to average on multi-gpu.
                if args.gradient_accumulation_steps > 1:
                    loss = loss / args.gradient_accumulation_steps

                if args.fp16:
                    optimizer.backward(loss)
                else:
                    loss.backward()
                
                # Clip gradients to limit their norm (args.clip) and stabilize training.
                if args.clip is not None:
                    torch.nn.utils.clip_grad_norm_(model.parameters(), args.clip)

                tr_loss += loss.item()
                nb_tr_examples += input_ids.size(0)
                nb_tr_steps += 1
                if (step + 1) % args.gradient_accumulation_steps == 0:
                    if args.fp16:
                        # modify learning rate with special warm up BERT uses
                        # if args.fp16 is False, BertAdam is used that handles this automatically
                        lr_this_step = args.learning_rate * warmup_linear(global_step/num_train_optimization_steps, args.warmup_proportion)
                        for param_group in optimizer.param_groups:
                            param_group['lr'] = lr_this_step
                    optimizer.step()
                    optimizer.zero_grad()
                    global_step += 1

    if args.do_train:
        # Save a trained model and the associated configuration
        model_to_save = model.module if hasattr(model, 'module') else model  # Only save the model itself
        output_model_file = os.path.join(args.output_dir, WEIGHTS_NAME)
        torch.save(model_to_save.state_dict(), output_model_file)
        output_config_file = os.path.join(args.output_dir, CONFIG_NAME)
        with open(output_config_file, 'w') as f:
            f.write(model_to_save.config.to_json_string())

        # Load a trained model and config that you have fine-tuned
        config = BertConfig(output_config_file)
        model = BertForTokenClassification(config, num_labels=num_labels)
        model.load_state_dict(torch.load(output_model_file))
    else:
        #model = BertForSequenceClassification.from_pretrained(args.bert_model, num_labels=num_labels)
        # Load a trained model and config that you have fine-tuned
        print('Loading the fine-tuned model for evaluation only')
        output_model_file = os.path.join(args.output_dir, WEIGHTS_NAME)
        output_config_file = os.path.join(args.output_dir, CONFIG_NAME)
        config = BertConfig(output_config_file)
        model = BertForTokenClassification(config, num_labels=num_labels)
        model.load_state_dict(torch.load(output_model_file))
    model.to(device)

    if args.do_eval and (args.local_rank == -1 or torch.distributed.get_rank() == 0):
        eval_examples = processor.get_dev_examples(args.data_dir)
        #import pdb;pdb.set_trace()
        print("dev_eaxmples :: ",len(list(eval_examples)))
        eval_features = convert_examples_to_features(eval_examples, label_list, args.max_seq_length, tokenizer)
        all_input_ids = torch.tensor([f.input_ids for f in eval_features], dtype=torch.long)
        all_input_mask = torch.tensor([f.input_mask for f in eval_features], dtype=torch.long)
        all_segment_ids = torch.tensor([f.segment_ids for f in eval_features], dtype=torch.long)
        all_label_ids = torch.tensor([f.label_id for f in eval_features], dtype=torch.long)
        eval_data = TensorDataset(all_input_ids, all_input_mask, all_segment_ids, all_label_ids)
        # Run prediction for full data
        eval_sampler = SequentialSampler(eval_data)
        eval_dataloader = DataLoader(eval_data, sampler=eval_sampler, batch_size=args.eval_batch_size)

        model.eval()
        eval_loss, eval_accuracy = 0, 0
        nb_eval_steps, nb_eval_examples = 0, 0
        predictions , true_labels = [], [] 
        #predictions1 , true_labels1 = [], []

        for input_ids, input_mask, segment_ids, label_ids in tqdm(eval_dataloader, desc="Evaluating"):
            input_ids = input_ids.to(device)
            input_mask = input_mask.to(device)
            segment_ids = segment_ids.to(device)
            label_ids = label_ids.to(device)

            with torch.no_grad():
                tmp_eval_loss = model(input_ids, segment_ids, input_mask, label_ids)
                logits = model(input_ids, segment_ids, input_mask)

            logits = logits.detach().cpu().numpy()
            label_ids = label_ids.to('cpu').numpy()
            # Truncate each sequence at the first '[SEP]' token: keep predictions and labels
            # only up to (not including) that position.
            #print("label_list index SEP : ",label_list.index('[SEP]'))
            pred_xx = [list(p) for p in np.argmax(logits, axis=2)]
            pred_xx = [i[:i.index(label_list.index('[SEP]'))] for i in pred_xx]
            label_ids_xx = [i[:i.index(label_list.index('[SEP]'))] for i in label_ids.tolist()]
            #print(label_ids_xx)
            #print(pred_xx)

            # Pad the shorter of the two sequences with label index 31 so that predictions
            # and labels have equal length for the metrics below.
            tmp_s = [max(len(i), len(j)) for i, j in zip(label_ids_xx, pred_xx)]
            tmp_u = [(i + [31] * (k - len(i)) if len(i) != k else i,
                      j + [31] * (k - len(j)) if len(j) != k else j)
                     for i, j, k in zip(label_ids_xx, pred_xx, tmp_s)]
            tmp_d1 = [h[0] for h in tmp_u]
            tmp_d2 = [h[1] for h in tmp_u]

            #print([list(p) for p in np.argmax(logits, axis=2)][:5])
            #tmp_eval_accuracy = flat_accuracy(logits, label_ids)
            tmp_eval_accuracy = flat_accc(pred_xx, label_ids_xx)
            #tmp_eval_accuracy = flat_accc(tmp_d1, tmp_d2)
            predictions.extend(tmp_d2)
            true_labels.append(tmp_d1)
            #predictions1.extend(pred_xx)
            #true_labels1.append(label_ids_xx)
            
            #print("tmp accuracy : ",tmp_eval_accuracy)
            eval_loss += tmp_eval_loss.mean().item()
            eval_accuracy += tmp_eval_accuracy
            nb_eval_examples += input_ids.size(0)
            nb_eval_steps += 1

        eval_loss = eval_loss / nb_eval_steps
        eval_accuracy = eval_accuracy / nb_eval_steps
        loss = tr_loss/nb_tr_steps if args.do_train else None

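        # Map indices back to tag strings; index 31 is the padding value added above, so padded
        # positions show up as the placeholder tags 'XXX' (predictions) and 'YYY' (labels).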
        pred_tags = [[label_list[p_i] if p_i!=31 else 'XXX' for p_i in p] for p in predictions]
        valid_tags = [[label_list[l_ii] if l_ii!=31 else 'YYY' for l_ii in l_i] for l in true_labels for l_i in l ]
        print("valid_tags : ",valid_tags[:10])
        print("pred_tags : ",pred_tags[:10])
        print("Validation F1-Score: {}".format(f1_score(valid_tags, pred_tags)))
        print("Validation accuracy_score : {}".format(accuracy_score(valid_tags, pred_tags)))
        print("Validation classification_report : {}".format(classification_report(valid_tags, pred_tags)))
        
        #print("X Validation F1-Score: {}".format(f1_score(true_labels1, predictions1)))
        #print("X Validation accuracy_score : {}".format(accuracy_score(true_labels1, predictions1)))
        #print("X Validation classification_report : {}".format(classification_report(true_labels1, predictions1)))


        result = {'eval_loss': eval_loss,
                  'eval_accuracy': eval_accuracy,
                  'global_step': global_step,
                  'loss': loss}
        print(result)
        output_eval_file = os.path.join(args.output_dir, "eval_results.txt")
        with open(output_eval_file, "w") as writer:
            for key in sorted(result.keys()):
                writer.write("%s = %s\n" % (key, str(result[key])))

    if args.do_test and (args.local_rank == -1 or torch.distributed.get_rank() == 0):
        eval_examples = processor.get_test_examples(args.data_dir)
        print('test examples len : {}'.format(len(eval_examples)))
        #import pdb;pdb.set_trace()
        eval_features = convert_examples_to_features(eval_examples, label_list, args.max_seq_length, tokenizer)
        all_input_ids = torch.tensor([f.input_ids for f in eval_features], dtype=torch.long)
        all_input_mask = torch.tensor([f.input_mask for f in eval_features], dtype=torch.long)
        all_segment_ids = torch.tensor([f.segment_ids for f in eval_features], dtype=torch.long)
        all_label_ids = torch.tensor([f.label_id for f in eval_features], dtype=torch.long)
        eval_data = TensorDataset(all_input_ids, all_input_mask, all_segment_ids, all_label_ids)
        # Run prediction for full data
        eval_sampler = SequentialSampler(eval_data)
        eval_dataloader = DataLoader(eval_data, sampler=eval_sampler, batch_size=args.eval_batch_size)

        model.eval()
        test_loss, test_accuracy = 0, 0
        nb_eval_steps, nb_eval_examples = 0, 0
        predictions , true_labels = [], [] 

        for input_ids, input_mask, segment_ids, label_ids in tqdm(eval_dataloader, desc="Evaluating"):
            input_ids = input_ids.to(device)
            input_mask = input_mask.to(device)
            segment_ids = segment_ids.to(device)
            label_ids = label_ids.to(device)

            with torch.no_grad():
                tmp_eval_loss = model(input_ids, segment_ids, input_mask, label_ids)
                logits = model(input_ids, segment_ids, input_mask)

            logits = logits.detach().cpu().numpy()
            label_ids = label_ids.to('cpu').numpy()
            # Truncate each sequence at the first '[SEP]' token: keep predictions and labels
            # only up to (not including) that position.
            #print("label_list index SEP : ",label_list.index('[SEP]'))
            pred_xx = [list(p) for p in np.argmax(logits, axis=2)]
            pred_xx = [i[:i.index(label_list.index('[SEP]'))] for i in pred_xx]
            label_ids_xx = [i[:i.index(label_list.index('[SEP]'))] for i in label_ids.tolist()]
            #print(label_ids_xx)
            #print(pred_xx)

            # Pad the shorter of the two sequences with label index 31 so that predictions
            # and labels have equal length for the metrics below.
            tmp_s = [max(len(i), len(j)) for i, j in zip(label_ids_xx, pred_xx)]
            tmp_u = [(i + [31] * (k - len(i)) if len(i) != k else i,
                      j + [31] * (k - len(j)) if len(j) != k else j)
                     for i, j, k in zip(label_ids_xx, pred_xx, tmp_s)]
            tmp_d1 = [h[0] for h in tmp_u]
            tmp_d2 = [h[1] for h in tmp_u]

            #print([list(p) for p in np.argmax(logits, axis=2)][:5])
            #tmp_eval_accuracy = flat_accuracy(logits, label_ids)
            tmp_eval_accuracy = flat_accc(pred_xx, label_ids_xx)
            #tmp_eval_accuracy = flat_accc(tmp_d1, tmp_d2)
            predictions.extend(tmp_d2)
            true_labels.append(tmp_d1)
            #print("tmp accuracy : ",tmp_eval_accuracy)
            test_loss += tmp_eval_loss.mean().item()
            test_accuracy += tmp_eval_accuracy
            nb_eval_examples += input_ids.size(0)
            nb_eval_steps += 1

        test_loss = test_loss / nb_eval_steps
        test_accuracy = test_accuracy / nb_eval_steps
        loss = tr_loss/nb_tr_steps if args.do_train else None

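        # As in the dev evaluation, index 31 marks padded positions and is rendered as 'XXX'/'YYY'.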
        pred_tags = [[label_list[p_i] if p_i!=31 else 'XXX' for p_i in p] for p in predictions]
        valid_tags = [[label_list[l_ii] if l_ii!=31 else 'YYY' for l_ii in l_i] for l in true_labels for l_i in l ]
        print("valid_tags : ",valid_tags[:10])
        print("pred_tags : ",pred_tags[:10])
        print("Test F1-Score: {}".format(f1_score(valid_tags, pred_tags)))
        print("Test accuracy_score : {}".format(accuracy_score(valid_tags, pred_tags)))
        print("Test classification_report : {}".format(classification_report(valid_tags, pred_tags)))
        
        #print("X Test F1-Score: {}".format(f1_score(true_labels, predictions)))
        #print("X Test accuracy_score : {}".format(accuracy_score(true_labels, predictions)))
        #print("X Test classification_report : {}".format(classification_report(true_labels, predictions)))


        result = {'test_loss': test_loss,
                  'test_accuracy': test_accuracy,
                  'global_step': global_step,
                  'loss': loss}
        print(result)
        output_test_file = os.path.join(args.output_dir, "test_results.txt")
        with open(output_test_file, "w") as writer:
            for key in sorted(result.keys()):
                writer.write("%s = %s\n" % (key, str(result[key])))

    if args.do_pred and (args.local_rank == -1 or torch.distributed.get_rank() == 0):
        #eval_examples = processor.get_dev_examples(args.data_dir)
        model.eval()
        while True:
            print('Enter a text to get NER predictions, or press Ctrl+C to end the session.')
            text_a = input('>>>')
            #"Japan began the defence of their Asian Cup title with a lucky 2-1 win against Syria in a Group C championship match on Friday . ."
            eval_examples = {'text_a':text_a,'text_b':"The foodservice pie business does not fit our long-term growth strategy .",'label':'1','guid':'12345'}

            eval_features = convert_examples_to_features_test(eval_examples, label_list, args.max_seq_length, tokenizer)
            
            all_input_ids = torch.tensor([f.input_ids for f in eval_features], dtype=torch.long)
            all_input_mask = torch.tensor([f.input_mask for f in eval_features], dtype=torch.long)
            all_segment_ids = torch.tensor([f.segment_ids for f in eval_features], dtype=torch.long)
            all_label_ids = torch.tensor([f.label_id for f in eval_features], dtype=torch.long)
            eval_data = TensorDataset(all_input_ids, all_input_mask, all_segment_ids, all_label_ids)
            # Run prediction for full data
            eval_sampler = SequentialSampler(eval_data)
            eval_dataloader = DataLoader(eval_data, sampler=eval_sampler, batch_size=args.eval_batch_size)

            #model.eval()
            eval_loss, eval_accuracy = 0, 0
            nb_eval_steps, nb_eval_examples = 0, 0
            predictions , true_labels = [], [] 

            for input_ids, input_mask, segment_ids, label_ids in tqdm(eval_dataloader, desc="Evaluating"):
                input_ids = input_ids.to(device)
                input_mask = input_mask.to(device)
                segment_ids = segment_ids.to(device)
                label_ids = label_ids.to(device)

                with torch.no_grad():
                    tmp_eval_loss = model(input_ids, segment_ids, input_mask, label_ids)
                    logits = model(input_ids, segment_ids, input_mask)

                logits = logits.detach().cpu().numpy()
                label_ids = label_ids.to('cpu').numpy()

                pred_xx = [list(p) for p in np.argmax(logits, axis=2)]
                pred_xx = [i[:i.index(label_list.index('[SEP]'))] for i in pred_xx]

                print(pred_xx)
                print([[label_list[p_i] if p_i!=31 else 'XXX' for p_i in p] for p in pred_xx]) 
Example #22
0
def main():
    global best_prec1, args

    args.gpu = 0
    if args.distributed:
        args.gpu = args.rank % torch.cuda.device_count()

    if args.distributed:
        torch.cuda.set_device(args.gpu)
        dist.init_process_group(backend=args.dist_backend,
                                init_method=args.dist_url,
                                world_size=args.world_size,
                                rank=args.rank)

    if args.fp16:
        assert torch.backends.cudnn.enabled, "fp16 mode requires cudnn backend to be enabled."

    # create model
    if args.pretrained:
        print("=> using pre-trained model '{}'".format(args.arch))
        model = models.__dict__[args.arch](pretrained=True,
                                           num_classes=args.num_classes)
    else:
        print("=> creating model '{}'".format(args.arch))
        model = models.__dict__[args.arch](num_classes=args.num_classes)

    model = model.cuda()
    if args.fp16:
        model = network_to_half(model)
    if args.distributed:
        model = DDP(model)

    global model_params, master_params
    if args.fp16:
        model_params, master_params = prep_param_lists(model)
    else:
        master_params = list(model.parameters())

    # define loss function (criterion) and optimizer
    criterion = nn.CrossEntropyLoss().cuda()

    optimizer = torch.optim.SGD(master_params,
                                args.lr,
                                momentum=args.momentum,
                                weight_decay=args.weight_decay)

    # optionally resume from a checkpoint
    if args.resume:
        if os.path.isfile(args.resume):
            print("=> loading checkpoint '{}'".format(args.resume))
            checkpoint = torch.load(
                args.resume,
                map_location=lambda storage, loc: storage.cuda(args.gpu))
            args.start_epoch = checkpoint['epoch']
            best_prec1 = checkpoint['best_prec1']
            model.load_state_dict(checkpoint['state_dict'])
            optimizer.load_state_dict(checkpoint['optimizer'])
            print("=> loaded checkpoint '{}' (epoch {})".format(
                args.resume, checkpoint['epoch']))
        else:
            print("=> no checkpoint found at '{}'".format(args.resume))

    # Data loading code
    if len(args.data) == 1:
        traindir = os.path.join(args.data[0], 'train')
        valdir = os.path.join(args.data[0], 'val')
    else:
        traindir = args.data[0]
        valdir = args.data[1]

    pipe = HybridPipe(batch_size=args.batch_size,
                      num_threads=args.workers,
                      device_id=args.rank,
                      data_dir=traindir)
    pipe.build()
    test_run = pipe.run()
    from nvidia.dali.plugin.pytorch import DALIClassificationIterator
    train_loader = DALIClassificationIterator(pipe,
                                              size=int(1281167 /
                                                       args.world_size))

    pipe = HybridPipe(batch_size=args.batch_size,
                      num_threads=args.workers,
                      device_id=args.rank,
                      data_dir=valdir)
    pipe.build()
    test_run = pipe.run()
    from nvidia.dali.plugin.pytorch import DALIClassificationIterator
    val_loader = DALIClassificationIterator(pipe,
                                            size=int(50000 / args.world_size))

    if args.evaluate:
        validate(val_loader, model, criterion)
        return

    for epoch in range(args.start_epoch, args.epochs):
        adjust_learning_rate(optimizer, epoch)

        # train for one epoch
        train(train_loader, model, criterion, optimizer, epoch)
        if args.prof:
            break
        # evaluate on validation set
        prec1 = validate(val_loader, model, criterion)

        # remember best prec@1 and save checkpoint
        if args.rank == 0:
            is_best = prec1 > best_prec1
            best_prec1 = max(prec1, best_prec1)
            save_checkpoint(
                {
                    'epoch': epoch + 1,
                    'arch': args.arch,
                    'state_dict': model.state_dict(),
                    'best_prec1': best_prec1,
                    'optimizer': optimizer.state_dict(),
                }, is_best)

        # reset DALI iterators
        train_loader.reset()
        val_loader.reset()
Example #23
0
resume_epoch = train_config['resume_epoch']
epochs = train_config['epoch_num']
learning_rate = 0.
# yolo = init_model(config_map)
yolo = YOLO(config_map, logger=logger, vis=my_vis).to(device)
optimizer = optim.SGD(yolo.parameters(), lr=learning_rate, momentum=momentum, weight_decay=weight_decay)


# yolo_p = nn.parallel.DistributedDataParallel(yolo.to(device), device_ids=train_config['gpu_ids'])
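# With fp16 training enabled, wrap model and optimizer with Apex amp (O1, static loss scale) and use
# the Apex DistributedDataParallel; otherwise fall back to torch.nn.parallel.DistributedDataParallel.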
if config_map.fp16_training:
    yolo, optimizer = amp.initialize(yolo, optimizer, opt_level='O1', loss_scale=128.0)
    yolo_p = DDP(yolo)
else:
    yolo_p = nn.parallel.DistributedDataParallel(yolo, device_ids=train_config['gpu_ids'])
if train_config['resume_from_path']:
    yolo_p.load_state_dict(torch.load(train_config['resume_from_path']))



# optimizer = optim.SGD(yolo_p.parameters(), lr=lr0, momentum=momentum, weight_decay=weight_decay) # , weight_decay=5e-4)
# optimizer = optim.Adam(yolo_p.parameters())

# yolo_p.load_state_dict(torch.load('densenet_sgd_S7_yolo.pth'))

yolo_p.module.train()
print(yolo_p)


train_dataset = yoloDataset(list_file=train_config['train_txt_path'], train=False, little_train=False, input_size=config_map.input_size, test_mode=False)
train_loader = DataLoader(train_dataset, batch_size=train_config['batch_size'], shuffle=True, num_workers=train_config['worker_num'], collate_fn=train_dataset.collate_fn)
Example #24
0
def main():
    global best_prec1, args

    args.gpu = 0
    args.world_size = 1

    if args.distributed:
        args.gpu = args.local_rank % torch.cuda.device_count()
        torch.cuda.set_device(args.gpu)
        torch.distributed.init_process_group(backend='nccl',
                                             init_method='env://')
        args.world_size = torch.distributed.get_world_size()

    args.total_batch_size = args.world_size * args.batch_size

    if not os.path.isdir(args.checkpoint) and args.local_rank == 0:
        mkdir_p(args.checkpoint)

    if args.fp16:
        assert torch.backends.cudnn.enabled, "fp16 mode requires cudnn backend to be enabled."

    if args.static_loss_scale != 1.0:
        if not args.fp16:
            print(
                "Warning:  if --fp16 is not used, static_loss_scale will be ignored."
            )

    # create model
    if args.pretrained:
        print("=> using pre-trained model '{}'".format(args.arch))
        model = models.__dict__[args.arch](pretrained=True)
    else:
        print("=> creating model '{}'".format(args.arch))
        model = models.__dict__[args.arch]()

    model = model.cuda()
    if args.fp16:
        model = network_to_half(model)
    if args.distributed:
        # shared param/delay all reduce turns off bucketing in DDP, for lower latency runs this can improve perf
        # for the older version of APEX please use shared_param, for newer one it is delay_allreduce
        model = DDP(model, delay_allreduce=True)

    # define loss function (criterion) and optimizer
    criterion = nn.CrossEntropyLoss().cuda()

    optimizer = torch.optim.SGD(model.parameters(),
                                args.lr,
                                momentum=args.momentum,
                                weight_decay=args.weight_decay)
    if args.fp16:
        optimizer = FP16_Optimizer(optimizer,
                                   static_loss_scale=args.static_loss_scale,
                                   dynamic_loss_scale=args.dynamic_loss_scale,
                                   verbose=False)

    # optionally resume from a checkpoint
    title = 'ImageNet-' + args.arch
    if args.resume:
        if os.path.isfile(args.resume):
            print("=> loading checkpoint '{}'".format(args.resume))
            checkpoint = torch.load(
                args.resume,
                map_location=lambda storage, loc: storage.cuda(args.gpu))
            args.start_epoch = checkpoint['epoch']
            best_prec1 = checkpoint['best_prec1']
            model.load_state_dict(checkpoint['state_dict'])
            optimizer.load_state_dict(checkpoint['optimizer'])
            print("=> loaded checkpoint '{}' (epoch {})".format(
                args.resume, checkpoint['epoch']))
            if args.local_rank == 0:
                logger = Logger(os.path.join(args.checkpoint, 'log.txt'),
                                title=title,
                                resume=True)
        else:
            print("=> no checkpoint found at '{}'".format(args.resume))
    else:
        if args.local_rank == 0:
            logger = Logger(os.path.join(args.checkpoint, 'log.txt'),
                            title=title)
            logger.set_names([
                'Learning Rate', 'Train Loss', 'Valid Loss', 'Train Acc.',
                'Valid Acc.', 'Valid Top5.'
            ])

    traindir = os.path.join(args.data, 'train')
    valdir = os.path.join(args.data, 'val')

    if (args.arch == "inception_v3"):
        crop_size = 299
        val_size = 320  # I chose this value arbitrarily, we can adjust.
    else:
        crop_size = 224
        val_size = 256

    pipe = HybridTrainPipe(batch_size=args.batch_size,
                           num_threads=args.workers,
                           device_id=args.local_rank,
                           data_dir=traindir,
                           crop=crop_size,
                           dali_cpu=args.dali_cpu)
    pipe.build()
    train_loader = DALIClassificationIterator(
        pipe, size=int(pipe.epoch_size("Reader") / args.world_size))

    pipe = HybridValPipe(batch_size=args.batch_size,
                         num_threads=args.workers,
                         device_id=args.local_rank,
                         data_dir=valdir,
                         crop=crop_size,
                         size=val_size)
    pipe.build()
    val_loader = DALIClassificationIterator(
        pipe, size=int(pipe.epoch_size("Reader") / args.world_size))

    if args.evaluate:
        validate(val_loader, model, criterion)
        return

    total_time = AverageMeter()
    for epoch in range(args.start_epoch, args.epochs):
        # train for one epoch
        adjust_learning_rate(optimizer, epoch, args)

        if args.local_rank == 0:
            print('\nEpoch: [%d | %d] LR: %f' %
                  (epoch + 1, args.epochs, optimizer.param_groups[0]['lr']))

        [train_loss, train_acc,
         avg_train_time] = train(train_loader, model, criterion, optimizer,
                                 epoch)
        total_time.update(avg_train_time)
        # evaluate on validation set
        [test_loss, prec1, prec5] = validate(val_loader, model, criterion)

        # remember best prec@1 and save checkpoint
        if args.local_rank == 0:
            # append logger file
            logger.append([
                optimizer.param_groups[0]['lr'], train_loss, test_loss,
                train_acc, prec1, prec5
            ])

            is_best = prec1 > best_prec1
            best_prec1 = max(prec1, best_prec1)
            save_checkpoint(
                {
                    'epoch': epoch + 1,
                    'arch': args.arch,
                    'state_dict': model.state_dict(),
                    'best_prec1': best_prec1,
                    'optimizer': optimizer.state_dict(),
                },
                is_best,
                checkpoint=args.checkpoint)
            if epoch == args.epochs - 1:
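                # Perf is reported as throughput: global batch size divided by the average per-iteration train time.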
                print('##Top-1 {0}\n'
                      '##Top-5 {1}\n'
                      '##Perf  {2}'.format(
                          prec1, prec5,
                          args.total_batch_size / total_time.avg))

        # reset DALI iterators
        train_loader.reset()
        val_loader.reset()

    if args.local_rank == 0:
        logger.close()
Example #25
0
class face_learner(object):
    def __init__(self, conf, args, inference=False):
        print(conf)
        self.local_rank = args.local_rank
        if conf.use_mobilfacenet:
            self.model = MobileFaceNet(conf.embedding_size).to(conf.device)
            print('MobileFaceNet model generated')
        else:
            self.model = Backbone(conf.net_depth, conf.drop_ratio,
                                  conf.net_mode).cuda()
            print('{}_{} model generated'.format(conf.net_mode,
                                                 conf.net_depth))

        if not inference:
            self.milestones = conf.milestones
            self.loader, self.class_num = get_train_loader(conf)

            self.writer = SummaryWriter(conf.log_path)
            self.step = 0
            self.head = Arcface(embedding_size=conf.embedding_size,
                                classnum=self.class_num).cuda()

            print('two model heads generated')

            paras_only_bn, paras_wo_bn = separate_bn_paras(self.model)

            if conf.use_mobilfacenet:
                self.optimizer = optim.SGD(
                    [{
                        'params': paras_wo_bn[:-1],
                        'weight_decay': 4e-5
                    }, {
                        'params': [paras_wo_bn[-1]] + [self.head.kernel],
                        'weight_decay': 4e-4
                    }, {
                        'params': paras_only_bn
                    }],
                    lr=conf.lr,
                    momentum=conf.momentum)
            else:
                self.optimizer = optim.SGD(
                    [{
                        'params': paras_wo_bn + [self.head.kernel],
                        'weight_decay': 5e-4
                    }, {
                        'params': paras_only_bn
                    }],
                    lr=conf.lr,
                    momentum=conf.momentum)
            print(self.optimizer)
            #[self.model, self.head], self.optimizer = amp.initialize([self.model, self.head], self.optimizer, opt_level='O1')
            [self.model, self.head
             ], self.optimizer = amp.initialize([self.model, self.head],
                                                self.optimizer,
                                                opt_level='O3',
                                                keep_batchnorm_fp32=True)
            print(self.optimizer, args.local_rank)
            self.head = DistributedDataParallel(self.head)
            self.model = DistributedDataParallel(self.model)
            #self.model = torch.nn.parallel.DistributedDataParallel(self.model, device_ids=[args.local_rank])
            #             self.scheduler = optim.lr_scheduler.ReduceLROnPlateau(self.optimizer, patience=40, verbose=True)

            print('optimizers generated')
            self.board_loss_every = len(self.loader) // 100
            self.evaluate_every = len(self.loader) // 10
            self.save_every = len(self.loader) // 5
            self.agedb_30, self.cfp_fp, self.lfw, self.agedb_30_issame, self.cfp_fp_issame, self.lfw_issame = get_val_data(
                self.loader.dataset.root.parent)
        else:
            self.threshold = conf.threshold

    def save_state(self,
                   conf,
                   accuracy,
                   to_save_folder=False,
                   extra=None,
                   model_only=False):
        if to_save_folder:
            save_path = conf.save_path
        else:
            save_path = conf.model_path
        torch.save(
            self.model.state_dict(),
            save_path / ('model_{}_accuracy:{}_step:{}_{}.pth'.format(
                get_time(), accuracy, self.step, extra)))
        if not model_only:
            torch.save(
                self.head.state_dict(),
                save_path / ('head_{}_accuracy:{}_step:{}_{}.pth'.format(
                    get_time(), accuracy, self.step, extra)))
            torch.save(
                self.optimizer.state_dict(),
                save_path / ('optimizer_{}_accuracy:{}_step:{}_{}.pth'.format(
                    get_time(), accuracy, self.step, extra)))

    def load_state(self,
                   conf,
                   fixed_str,
                   from_save_folder=False,
                   model_only=False):
        if from_save_folder:
            save_path = conf.save_path
        else:
            save_path = conf.model_path
        self.model.load_state_dict(
            torch.load(save_path / 'model_{}'.format(fixed_str)))
        if not model_only:
            self.head.load_state_dict(
                torch.load(save_path / 'head_{}'.format(fixed_str)))
            self.optimizer.load_state_dict(
                torch.load(save_path / 'optimizer_{}'.format(fixed_str)))

    def board_val(self, db_name, accuracy, best_threshold, roc_curve_tensor):
        self.writer.add_scalar('{}_accuracy'.format(db_name), accuracy,
                               self.step)
        self.writer.add_scalar('{}_best_threshold'.format(db_name),
                               best_threshold, self.step)
        self.writer.add_image('{}_roc_curve'.format(db_name), roc_curve_tensor,
                              self.step)
#         self.writer.add_scalar('{}_val:true accept ratio'.format(db_name), val, self.step)
#         self.writer.add_scalar('{}_val_std'.format(db_name), val_std, self.step)
#         self.writer.add_scalar('{}_far:False Acceptance Ratio'.format(db_name), far, self.step)

    def evaluate(self, conf, carray, issame, nrof_folds=5, tta=False):
        self.model.eval()
        idx = 0
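        # Embed the validation array in batches; with TTA, sum the embeddings of the original and
        # horizontally flipped images and L2-normalize the result.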
        embeddings = np.zeros([len(carray), conf.embedding_size])
        with torch.no_grad():
            while idx + conf.batch_size <= len(carray):
                batch = torch.tensor(carray[idx:idx + conf.batch_size])
                if tta:
                    fliped = hflip_batch(batch)
                    emb_batch = self.model(batch.to(conf.device)) + self.model(
                        fliped.to(conf.device))
                    embeddings[idx:idx + conf.batch_size] = l2_norm(emb_batch)
                else:
                    embeddings[idx:idx + conf.batch_size] = self.model(
                        batch.to(conf.device)).cpu()
                idx += conf.batch_size
            if idx < len(carray):
                batch = torch.tensor(carray[idx:])
                if tta:
                    fliped = hflip_batch(batch)
                    emb_batch = self.model(batch.to(conf.device)) + self.model(
                        fliped.to(conf.device))
                    embeddings[idx:] = l2_norm(emb_batch)
                else:
                    embeddings[idx:] = self.model(batch.to(conf.device)).cpu()
        tpr, fpr, accuracy, best_thresholds = evaluate(embeddings, issame,
                                                       nrof_folds)
        buf = gen_plot(fpr, tpr)
        roc_curve = Image.open(buf)
        roc_curve_tensor = trans.ToTensor()(roc_curve)
        return accuracy.mean(), best_thresholds.mean(), roc_curve_tensor

    def find_lr(self,
                conf,
                init_value=1e-8,
                final_value=10.,
                beta=0.98,
                bloding_scale=3.,
                num=None):
        if not num:
            num = len(self.loader)
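        # LR range test: grow the learning rate exponentially from init_value to final_value over
        # `num` batches, recording the smoothed loss at each step.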
        mult = (final_value / init_value)**(1 / num)
        lr = init_value
        for params in self.optimizer.param_groups:
            params['lr'] = lr
        self.model.train()
        avg_loss = 0.
        best_loss = 0.
        batch_num = 0
        losses = []
        log_lrs = []
        for i, (imgs, labels) in tqdm(enumerate(self.loader), total=num):

            imgs = imgs.to(conf.device)
            labels = labels.to(conf.device)
            batch_num += 1

            self.optimizer.zero_grad()

            embeddings = self.model(imgs)
            thetas = self.head(embeddings, labels)
            loss = conf.ce_loss(thetas, labels)

            #Compute the smoothed loss
            avg_loss = beta * avg_loss + (1 - beta) * loss.item()
            self.writer.add_scalar('avg_loss', avg_loss, batch_num)
            smoothed_loss = avg_loss / (1 - beta**batch_num)
            self.writer.add_scalar('smoothed_loss', smoothed_loss, batch_num)
            #Stop if the loss is exploding
            if batch_num > 1 and smoothed_loss > bloding_scale * best_loss:
                print('exited with best_loss at {}'.format(best_loss))
                plt.plot(log_lrs[10:-5], losses[10:-5])
                return log_lrs, losses
            #Record the best loss
            if smoothed_loss < best_loss or batch_num == 1:
                best_loss = smoothed_loss
            #Store the values
            losses.append(smoothed_loss)
            log_lrs.append(math.log10(lr))
            self.writer.add_scalar('log_lr', math.log10(lr), batch_num)
            #Do the SGD step
            #Update the lr for the next step

            loss.backward()
            self.optimizer.step()

            lr *= mult
            for params in self.optimizer.param_groups:
                params['lr'] = lr
            if batch_num > num:
                plt.plot(log_lrs[10:-5], losses[10:-5])
                return log_lrs, losses

    def train(self, conf, epochs):
        self.model.train()
        #conf.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        #self.model = torch.nn.DataParallel(self.model,device_ids=[0,1,2,3,4,5,6,7])
        #self.model.to(conf.device)
        #self.model = torch.nn.parallel.DistributedDataParallel(self.model, device_ids=[self.local_rank])
        running_loss = 0.
        for e in range(epochs):
            print('epoch {} started'.format(e))
            tic = time.time()
            if e == self.milestones[0]:
                self.schedule_lr()
            if e == self.milestones[1]:
                self.schedule_lr()
            if e == self.milestones[2]:
                self.schedule_lr()
            for imgs, labels in tqdm(iter(self.loader)):
                imgs = imgs.to(conf.device)
                labels = labels.to(conf.device)
                self.optimizer.zero_grad()
                embeddings = self.model(imgs)
                thetas = self.head(embeddings, labels)
                loss = conf.ce_loss(thetas, labels)
                with amp.scale_loss(loss, self.optimizer) as scaled_loss:
                    scaled_loss.backward()
                # Log the unscaled loss; scaled_loss is multiplied by the amp loss-scaling factor.
                running_loss += loss.item()
                self.optimizer.step()

                if self.step % self.board_loss_every == 0 and self.step != 0:
                    loss_board = running_loss / self.board_loss_every
                    self.writer.add_scalar('train_loss', loss_board, self.step)
                    running_loss = 0.

                if self.step % self.evaluate_every == 0 and self.step != 0:
                    accuracy, best_threshold, roc_curve_tensor = self.evaluate(
                        conf, self.agedb_30, self.agedb_30_issame)
                    self.board_val('agedb_30', accuracy, best_threshold,
                                   roc_curve_tensor)
                    accuracy, best_threshold, roc_curve_tensor = self.evaluate(
                        conf, self.lfw, self.lfw_issame)
                    self.board_val('lfw', accuracy, best_threshold,
                                   roc_curve_tensor)
                    accuracy, best_threshold, roc_curve_tensor = self.evaluate(
                        conf, self.cfp_fp, self.cfp_fp_issame)
                    self.board_val('cfp_fp', accuracy, best_threshold,
                                   roc_curve_tensor)
                    self.model.train()
                if self.step % self.save_every == 0 and self.step != 0:
                    self.save_state(conf, accuracy)

                self.step += 1
            toc = time.time()
            print('epoch {} time'.format(e), toc - tic)
        self.save_state(conf, accuracy, to_save_folder=True, extra='final')

    def schedule_lr(self):
        for params in self.optimizer.param_groups:
            params['lr'] /= 10
        print(self.optimizer)

    def infer(self, conf, faces, target_embs, tta=False):
        '''
        faces : list of PIL Image
        target_embs : [n, 512] computed embeddings of faces in facebank
        names : recorded names of faces in facebank
        tta : test time augmentation (hflip, that's all)
        '''
        embs = []
        for img in faces:
            if tta:
                mirror = trans.functional.hflip(img)
                emb = self.model(
                    conf.test_transform(img).to(conf.device).unsqueeze(0))
                emb_mirror = self.model(
                    conf.test_transform(mirror).to(conf.device).unsqueeze(0))
                embs.append(l2_norm(emb + emb_mirror))
            else:
                embs.append(
                    self.model(
                        conf.test_transform(img).to(conf.device).unsqueeze(0)))
        source_embs = torch.cat(embs)

        diff = source_embs.unsqueeze(-1) - target_embs.transpose(
            1, 0).unsqueeze(0)
        dist = torch.sum(torch.pow(diff, 2), dim=1)
        minimum, min_idx = torch.min(dist, dim=1)
        min_idx[minimum > self.threshold] = -1  # if no match, set idx to -1
        return min_idx, minimum
Example #26
0
def main():
    global best_prec1, args

    args.distributed = args.world_size > 1
    args.gpu = 0
    if args.distributed:
        args.gpu = args.rank % torch.cuda.device_count()
        

    if args.distributed:
        torch.cuda.set_device(args.gpu)
        dist.init_process_group(backend=args.dist_backend, 
                                init_method=args.dist_url,
                                world_size=args.world_size,
                                rank=args.rank)

    if args.fp16:
        assert torch.backends.cudnn.enabled, "fp16 mode requires cudnn backend to be enabled."

    # create model
    if args.pretrained:
        print("=> using pre-trained model '{}'".format(args.arch))
        model = models.__dict__[args.arch](pretrained=True)
    else:
        print("=> creating model '{}'".format(args.arch))
        model = models.__dict__[args.arch]()

    model = model.cuda()
    if args.fp16:
        model = network_to_half(model)
    if args.distributed:
        #shared param turns off bucketing in DDP, for lower latency runs this can improve perf
        model = DDP(model, shared_param=True)

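    # For fp16 runs, keep an fp32 master copy of the parameters for the optimizer to update.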
    global model_params, master_params
    if args.fp16:
        model_params, master_params = prep_param_lists(model)
    else:
        master_params = list(model.parameters())

    # define loss function (criterion) and optimizer
    criterion = nn.CrossEntropyLoss().cuda()

    optimizer = torch.optim.SGD(master_params, args.lr,
                                momentum=args.momentum,
                                weight_decay=args.weight_decay)

    # optionally resume from a checkpoint
    if args.resume:
        if os.path.isfile(args.resume):
            print("=> loading checkpoint '{}'".format(args.resume))
            checkpoint = torch.load(args.resume, map_location = lambda storage, loc: storage.cuda(args.gpu))
            args.start_epoch = checkpoint['epoch']
            best_prec1 = checkpoint['best_prec1']
            model.load_state_dict(checkpoint['state_dict'])
            optimizer.load_state_dict(checkpoint['optimizer'])
            print("=> loaded checkpoint '{}' (epoch {})"
                  .format(args.resume, checkpoint['epoch']))
        else:
            print("=> no checkpoint found at '{}'".format(args.resume))

    # Data loading code
    traindir = os.path.join(args.data, 'train')
    valdir = os.path.join(args.data, 'val')

    if(args.arch == "inception_v3"):
        crop_size = 299
        val_size = 320 # I chose this value arbitrarily, we can adjust.
    else:
        crop_size = 224
        val_size = 256

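    # ToTensor/normalize are intentionally omitted from the transforms; the conversion and
    # normalization are presumably handled downstream (fast_collate / the training loop) for speed.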
    train_dataset = datasets.ImageFolder(
        traindir,
        transforms.Compose([
            transforms.RandomResizedCrop(crop_size),
            transforms.RandomHorizontalFlip(),
            # transforms.ToTensor(), Too slow
            # normalize,
        ]))

    if args.distributed:
        train_sampler = torch.utils.data.distributed.DistributedSampler(train_dataset)
    else:
        train_sampler = None

    train_loader = torch.utils.data.DataLoader(
        train_dataset, batch_size=args.batch_size, shuffle=(train_sampler is None),
        num_workers=args.workers, pin_memory=True, sampler=train_sampler, collate_fn=fast_collate)

    val_loader = torch.utils.data.DataLoader(
        datasets.ImageFolder(valdir, transforms.Compose([
            transforms.Resize(val_size),
            transforms.CenterCrop(crop_size),
        ])),
        batch_size=args.batch_size, shuffle=False,
        num_workers=args.workers, pin_memory=True,
        collate_fn=fast_collate)

    if args.evaluate:
        validate(val_loader, model, criterion)
        return

    for epoch in range(args.start_epoch, args.epochs):
        if args.distributed:
            train_sampler.set_epoch(epoch)
        adjust_learning_rate(optimizer, epoch)

        # train for one epoch
        train(train_loader, model, criterion, optimizer, epoch)
        if args.prof:
            break
        # evaluate on validation set
        prec1 = validate(val_loader, model, criterion)

        # remember best prec@1 and save checkpoint
        if args.rank == 0:
            is_best = prec1 > best_prec1
            best_prec1 = max(prec1, best_prec1)
            save_checkpoint({
                'epoch': epoch + 1,
                'arch': args.arch,
                'state_dict': model.state_dict(),
                'best_prec1': best_prec1,
                'optimizer' : optimizer.state_dict(),
            }, is_best)
Example #27
0
def main():
    # def main(args):
    parser = setup_parser()
    args = parser.parse_args()

    # specifies the path where the biobert or clinical bert model is saved
    if args.bert_model == 'biobert' or args.bert_model == 'clinical_bert' or args.bert_model == 'stroke_bert':
        args.bert_model = args.model_loc

    print(args.bert_model)

    if args.server_ip and args.server_port:
        # Distant debugging - see https://code.visualstudio.com/docs/python/debugging#_attach-to-a-local-script
        import ptvsd
        print("Waiting for debugger attach")
        ptvsd.enable_attach(address=(args.server_ip, args.server_port),
                            redirect_output=True)
        ptvsd.wait_for_attach()

    processors = {
        "cola": ColaProcessor,
        "mnli": MnliProcessor,
        "mrpc": MrpcProcessor,
        "mednli": MedNLIProcessor,
        "carotid": CaroditProcessor
    }

    num_labels_task = {
        "cola": 2,
        "mnli": 3,
        "mrpc": 2,
        "mednli": 3,
        "carotid": 17
    }

    if args.local_rank == -1 or args.no_cuda:
        device = torch.device("cuda" if torch.cuda.is_available()
                              and not args.no_cuda else "cpu")
        n_gpu = torch.cuda.device_count()
    else:
        torch.cuda.set_device(args.local_rank)
        device = torch.device("cuda", args.local_rank)
        n_gpu = 1
        # Initializes the distributed backend which will take care of synchronizing nodes/GPUs
        torch.distributed.init_process_group(backend='nccl')
    logger.info(
        "device: {} n_gpu: {}, distributed training: {}, 16-bits training: {}".
        format(device, n_gpu, bool(args.local_rank != -1), args.fp16))

    if args.gradient_accumulation_steps < 1:
        raise ValueError(
            "Invalid gradient_accumulation_steps parameter: {}, should be >= 1"
            .format(args.gradient_accumulation_steps))

    args.train_batch_size = args.train_batch_size // args.gradient_accumulation_steps

    random.seed(args.seed)
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)
    if n_gpu > 0:
        torch.cuda.manual_seed_all(args.seed)

    if not args.do_train and not args.do_eval:
        raise ValueError(
            "At least one of `do_train` or `do_eval` must be True.")

    if os.path.exists(args.output_dir) and os.listdir(
            args.output_dir) and args.do_train:
        raise ValueError(
            "Output directory ({}) already exists and is not empty.".format(
                args.output_dir))
    if not os.path.exists(args.output_dir):
        os.makedirs(args.output_dir)

    task_name = args.task_name.lower()

    if task_name not in processors:
        raise ValueError("Task not found: %s" % (task_name))

    processor = processors[task_name]()
    num_labels = num_labels_task[task_name]
    label_list = processor.get_labels()

    tokenizer = BertTokenizer.from_pretrained(args.bert_model,
                                              do_lower_case=args.do_lower_case)

    print('TRAIN')
    train = processor.get_train_examples(args.data_dir)
    print([(train[i].text_a, train[i].text_b, train[i].label)
           for i in range(3)])
    print('DEV')
    dev = processor.get_dev_examples(args.data_dir)
    print([(dev[i].text_a, dev[i].text_b, dev[i].label) for i in range(3)])
    print('TEST')
    test = processor.get_test_examples(args.data_dir)
    print([(test[i].text_a, test[i].text_b, test[i].label) for i in range(3)])

    train_examples = None
    num_train_optimization_steps = -1
    if args.do_train:
        train_examples = processor.get_train_examples(args.data_dir)
        num_train_optimization_steps = int(
            len(train_examples) / args.train_batch_size /
            args.gradient_accumulation_steps) * args.num_train_epochs
        if args.local_rank != -1:
            num_train_optimization_steps = num_train_optimization_steps // torch.distributed.get_world_size(
            )

    # Prepare model
    cache_dir = args.cache_dir if args.cache_dir else os.path.join(
        PYTORCH_PRETRAINED_BERT_CACHE, 'distributed_{}'.format(
            args.local_rank))
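    # The carotid task is multi-label (17 labels), so it uses the multi-label classification head;
    # the other tasks use the standard single-label sequence classifier.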
    if task_name == 'carotid':
        model = BertForMultiLabelSequenceClassification.from_pretrained(
            args.bert_model, cache_dir=cache_dir, num_labels=num_labels)
    else:
        model = BertForSequenceClassification.from_pretrained(
            args.bert_model, cache_dir=cache_dir, num_labels=num_labels)
    if args.fp16:
        model.half()
    model.to(device)
    if args.local_rank != -1:
        try:
            from apex.parallel import DistributedDataParallel as DDP
        except ImportError:
            raise ImportError(
                "Please install apex from https://www.github.com/nvidia/apex to use distributed and fp16 training."
            )

        model = DDP(model)
    elif n_gpu > 1:
        model = torch.nn.DataParallel(model)

    # Prepare optimizer
    param_optimizer = list(model.named_parameters())
    no_decay = ['bias', 'LayerNorm.bias', 'LayerNorm.weight']
    optimizer_grouped_parameters = [{
        'params':
        [p for n, p in param_optimizer if not any(nd in n for nd in no_decay)],
        'weight_decay':
        0.01
    }, {
        'params':
        [p for n, p in param_optimizer if any(nd in n for nd in no_decay)],
        'weight_decay':
        0.0
    }]
    if args.fp16:
        try:
            from apex.optimizers import FP16_Optimizer
            from apex.optimizers import FusedAdam
        except ImportError:
            raise ImportError(
                "Please install apex from https://www.github.com/nvidia/apex to use distributed and fp16 training."
            )

        optimizer = FusedAdam(optimizer_grouped_parameters,
                              lr=args.learning_rate,
                              bias_correction=False,
                              max_grad_norm=1.0)
        if args.loss_scale == 0:
            optimizer = FP16_Optimizer(optimizer, dynamic_loss_scale=True)
        else:
            optimizer = FP16_Optimizer(optimizer,
                                       static_loss_scale=args.loss_scale)

    else:
        optimizer = BertAdam(optimizer_grouped_parameters,
                             lr=args.learning_rate,
                             warmup=args.warmup_proportion,
                             t_total=num_train_optimization_steps)

    global_step = 0
    nb_tr_steps = 0
    tr_loss = 0
    if args.do_train:
        train_features = convert_examples_to_features(train_examples,
                                                      label_list,
                                                      args.max_seq_length,
                                                      tokenizer, task_name)
        logger.info("***** Running training *****")
        logger.info("  Num examples = %d", len(train_examples))
        logger.info("  Batch size = %d", args.train_batch_size)
        logger.info("  Num steps = %d", num_train_optimization_steps)
        all_input_ids = torch.tensor([f.input_ids for f in train_features],
                                     dtype=torch.long)
        all_input_mask = torch.tensor([f.input_mask for f in train_features],
                                      dtype=torch.long)
        all_segment_ids = torch.tensor([f.segment_ids for f in train_features],
                                       dtype=torch.long)
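        # Multi-label targets are float (multi-hot) vectors; single-label targets are class indices (long).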
        if task_name == 'carotid':
            all_label_ids = torch.tensor([f.label_id for f in train_features],
                                         dtype=torch.float)
        else:
            all_label_ids = torch.tensor([f.label_id for f in train_features],
                                         dtype=torch.long)
        train_data = TensorDataset(all_input_ids, all_input_mask,
                                   all_segment_ids, all_label_ids)
        if args.local_rank == -1:
            train_sampler = RandomSampler(train_data)
        else:
            train_sampler = DistributedSampler(train_data)
        train_dataloader = DataLoader(train_data,
                                      sampler=train_sampler,
                                      batch_size=args.train_batch_size)
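        # RandomSampler shuffles within a single process; DistributedSampler instead
        # shards and shuffles the data across ranks so each GPU sees a distinct
        # subset of examples per epoch.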

        model.train()
        for _ in trange(int(args.num_train_epochs), desc="Epoch"):
            tr_loss = 0
            nb_tr_examples, nb_tr_steps = 0, 0
            for step, batch in enumerate(
                    tqdm(train_dataloader, desc="Iteration")):
                batch = tuple(t.to(device) for t in batch)
                input_ids, input_mask, segment_ids, label_ids = batch
                loss = model(input_ids, segment_ids, input_mask, label_ids)
                if n_gpu > 1:
                    loss = loss.mean()  # mean() to average on multi-gpu.
                if args.gradient_accumulation_steps > 1:
                    loss = loss / args.gradient_accumulation_steps
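                    # Dividing here makes the gradients summed over the accumulation
                    # window equivalent to one averaged batch.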

                if args.fp16:
                    optimizer.backward(loss)
                else:
                    loss.backward()

                tr_loss += loss.item()
                nb_tr_examples += input_ids.size(0)
                nb_tr_steps += 1
                if (step + 1) % args.gradient_accumulation_steps == 0:
                    if args.fp16:
                        # modify learning rate with special warm up BERT uses
                        # if args.fp16 is False, BertAdam is used that handles this automatically
                        lr_this_step = args.learning_rate * warmup_linear(
                            global_step / num_train_optimization_steps,
                            args.warmup_proportion)
                        for param_group in optimizer.param_groups:
                            param_group['lr'] = lr_this_step
                    optimizer.step()
                    optimizer.zero_grad()
                    global_step += 1

    if args.do_train:
        # Save a trained model and the associated configuration
        model_to_save = model.module if hasattr(
            model, 'module') else model  # Only save the model itself
        output_model_file = os.path.join(args.output_dir, WEIGHTS_NAME)
        torch.save(model_to_save.state_dict(), output_model_file)
        output_config_file = os.path.join(args.output_dir, CONFIG_NAME)
        with open(output_config_file, 'w') as f:
            f.write(model_to_save.config.to_json_string())

        # Load a trained model and config that you have fine-tuned
        config = BertConfig(output_config_file)
        if task_name == 'carotid':
            model = BertForMultiLabelSequenceClassification(
                config, num_labels=num_labels)
        else:
            model = BertForSequenceClassification(config,
                                                  num_labels=num_labels)
        model.load_state_dict(torch.load(output_model_file))
    else:
        if task_name == 'carotid':
            model = BertForMultiLabelSequenceClassification.from_pretrained(
                args.bert_model, num_labels=num_labels)
        else:
            model = BertForSequenceClassification.from_pretrained(
                args.bert_model, num_labels=num_labels)
    model.to(device)

    if args.do_eval and (args.local_rank == -1
                         or torch.distributed.get_rank() == 0):
        eval_examples = processor.get_dev_examples(args.data_dir)
        eval_features = convert_examples_to_features(eval_examples, label_list,
                                                     args.max_seq_length,
                                                     tokenizer, task_name)
        logger.info("***** Running evaluation *****")
        logger.info("  Num examples = %d", len(eval_examples))
        logger.info("  Batch size = %d", args.eval_batch_size)
        all_input_ids = torch.tensor([f.input_ids for f in eval_features],
                                     dtype=torch.long)
        all_input_mask = torch.tensor([f.input_mask for f in eval_features],
                                      dtype=torch.long)
        all_segment_ids = torch.tensor([f.segment_ids for f in eval_features],
                                       dtype=torch.long)
        if task_name == 'carotid':
            all_label_ids = torch.tensor([f.label_id for f in eval_features],
                                         dtype=torch.float)
        else:
            all_label_ids = torch.tensor([f.label_id for f in eval_features],
                                         dtype=torch.long)

        eval_data = TensorDataset(all_input_ids, all_input_mask,
                                  all_segment_ids, all_label_ids)
        # Run prediction for full data
        eval_sampler = SequentialSampler(eval_data)
        eval_dataloader = DataLoader(eval_data,
                                     sampler=eval_sampler,
                                     batch_size=args.eval_batch_size)

        all_logits = None
        all_labels = None

        model.eval()
        eval_loss, eval_accuracy = 0, 0
        nb_eval_steps, nb_eval_examples = 0, 0

        for input_ids, input_mask, segment_ids, label_ids in tqdm(
                eval_dataloader, desc="Evaluating"):
            input_ids = input_ids.to(device)
            input_mask = input_mask.to(device)
            segment_ids = segment_ids.to(device)
            label_ids = label_ids.to(device)

            with torch.no_grad():
                tmp_eval_loss = model(input_ids, segment_ids, input_mask,
                                      label_ids)
                logits = model(input_ids, segment_ids, input_mask)

            if task_name == 'carotid':
                if all_logits is None:
                    all_logits = logits.detach().cpu().numpy()
                else:
                    all_logits = np.concatenate(
                        (all_logits, logits.detach().cpu().numpy()), axis=0)

                if all_labels is None:
                    all_labels = label_ids.detach().cpu().numpy()
                else:
                    all_labels = np.concatenate(
                        (all_labels, label_ids.detach().cpu().numpy()), axis=0)
            else:
                logits = logits.detach().cpu().numpy()
                label_ids = label_ids.to('cpu').numpy()

                tmp_eval_accuracy = accuracy(logits, label_ids)

                eval_loss += tmp_eval_loss.mean().item()
                eval_accuracy += tmp_eval_accuracy

                nb_eval_examples += input_ids.size(0)
                nb_eval_steps += 1

        if task_name == 'carotid':
            fpr = dict()
            tpr = dict()
            roc_auc = dict()
            for i in range(num_labels):
                fpr[i], tpr[i], _ = roc_curve(all_labels[:, i], all_logits[:, i])
                roc_auc[i] = auc(fpr[i], tpr[i])
            # Compute micro-average ROC curve and ROC area
            fpr["micro"], tpr["micro"], _ = roc_curve(all_labels.ravel(),
                                                      all_logits.ravel())
            roc_auc["micro"] = auc(fpr["micro"], tpr["micro"])

            save_path = os.path.join(args.output_dir, "eval_prediction.pickle")
            predic_result = {
                'all_logits': all_logits,
                'all_labels': all_labels
            }
            with open(save_path, 'wb') as file_pi:
                pickle.dump(predic_result, file_pi)

            result = {'eval_loss': eval_loss, 'roc_auc': roc_auc}
        else:
            eval_loss = eval_loss / nb_eval_steps
            eval_accuracy = eval_accuracy / nb_eval_examples
            loss = tr_loss / nb_tr_steps if args.do_train else None
            result = {
                'eval_loss': eval_loss,
                'eval_accuracy': eval_accuracy,
                'global_step': global_step,
                'loss': loss
            }

        output_eval_file = os.path.join(args.output_dir, "eval_results.txt")
        with open(output_eval_file, "w") as writer:
            logger.info("***** Eval results *****")
            for key in sorted(result.keys()):
                logger.info("  %s = %s", key, str(result[key]))
                writer.write("%s = %s\n" % (key, str(result[key])))

    if args.do_test and (args.local_rank == -1
                         or torch.distributed.get_rank() == 0):
        test_examples = processor.get_test_examples(args.data_dir)
        test_features = convert_examples_to_features(test_examples, label_list,
                                                     args.max_seq_length,
                                                     tokenizer, task_name)
        logger.info("***** Running testing *****")
        logger.info("  Num examples = %d", len(test_examples))
        logger.info("  Batch size = %d", args.eval_batch_size)
        all_input_ids = torch.tensor([f.input_ids for f in test_features],
                                     dtype=torch.long)
        all_input_mask = torch.tensor([f.input_mask for f in test_features],
                                      dtype=torch.long)
        all_segment_ids = torch.tensor([f.segment_ids for f in test_features],
                                       dtype=torch.long)
        if task_name == 'carotid':
            all_label_ids = torch.tensor([f.label_id for f in test_features],
                                         dtype=torch.float)
        else:
            all_label_ids = torch.tensor([f.label_id for f in test_features],
                                         dtype=torch.long)
        test_data = TensorDataset(all_input_ids, all_input_mask,
                                  all_segment_ids, all_label_ids)
        # Run prediction for full data
        test_sampler = SequentialSampler(test_data)
        test_dataloader = DataLoader(test_data,
                                     sampler=test_sampler,
                                     batch_size=args.eval_batch_size)

        all_logits = None
        all_labels = None

        model.eval()
        test_loss, test_accuracy = 0, 0
        nb_test_steps, nb_test_examples = 0, 0

        for input_ids, input_mask, segment_ids, label_ids in tqdm(
                test_dataloader, desc="Testing"):
            input_ids = input_ids.to(device)
            input_mask = input_mask.to(device)
            segment_ids = segment_ids.to(device)
            label_ids = label_ids.to(device)

            with torch.no_grad():
                tmp_test_loss = model(input_ids, segment_ids, input_mask,
                                      label_ids)
                logits = model(input_ids, segment_ids, input_mask)

            if task_name == 'carotid':
                if all_logits is None:
                    all_logits = logits.detach().cpu().numpy()
                else:
                    all_logits = np.concatenate(
                        (all_logits, logits.detach().cpu().numpy()), axis=0)

                if all_labels is None:
                    all_labels = label_ids.detach().cpu().numpy()
                else:
                    all_labels = np.concatenate(
                        (all_labels, label_ids.detach().cpu().numpy()), axis=0)
            else:
                logits = logits.detach().cpu().numpy()
                label_ids = label_ids.to('cpu').numpy()
                tmp_test_accuracy = accuracy(logits, label_ids)

                test_loss += tmp_test_loss.mean().item()
                test_accuracy += tmp_test_accuracy

                nb_test_examples += input_ids.size(0)
                nb_test_steps += 1

        if task_name == 'carotid':
            fpr = dict()
            tpr = dict()
            roc_auc = dict()
            for i in range(num_labels):
                fpr[i], tpr[i], _ = roc_curve(all_labels[:, i], all_logits[:, i])
                roc_auc[i] = auc(fpr[i], tpr[i])
            # Compute micro-average ROC curve and ROC area
            fpr["micro"], tpr["micro"], _ = roc_curve(all_labels.ravel(),
                                                      all_logits.ravel())
            roc_auc["micro"] = auc(fpr["micro"], tpr["micro"])

            save_path = os.path.join(args.output_dir, "test_prediction.pickle")
            predic_result = {
                'all_logits': all_logits,
                'all_labels': all_labels
            }
            with open(save_path, 'wb') as file_pi:
                pickle.dump(predic_result, file_pi)

            result = {'test_loss': test_loss, 'roc_auc': roc_auc}
        else:
            test_loss = test_loss / nb_test_steps
            test_accuracy = test_accuracy / nb_test_examples
            loss = tr_loss / nb_tr_steps if args.do_train else None
            result = {
                'test_loss': test_loss,
                'test_accuracy': test_accuracy,
                'global_step': global_step,
                'loss': loss
            }

        output_test_file = os.path.join(args.output_dir, "test_results.txt")
        with open(output_test_file, "w") as writer:
            logger.info("***** Test results *****")
            for key in sorted(result.keys()):
                logger.info("  %s = %s", key, str(result[key]))
                writer.write("%s = %s\n" % (key, str(result[key])))
def main():
    parser = argparse.ArgumentParser()

    ## Required parameters
    parser.add_argument("--bert_model", default=None, type=str, required=True,
                        help="Bert pre-trained model selected in the list: bert-base-uncased, "
                        "bert-large-uncased, bert-base-cased, bert-large-cased, bert-base-multilingual-uncased, "
                        "bert-base-multilingual-cased, bert-base-chinese.")
    parser.add_argument("--output_dir", default=None, type=str, required=True,
                        help="The output directory where the model checkpoints and predictions will be written.")

    ## Other parameters
    parser.add_argument("--train_file", default=None, type=str, help="SQuAD json for training. E.g., train-v1.1.json")
    parser.add_argument("--predict_file", default=None, type=str,
                        help="SQuAD json for predictions. E.g., dev-v1.1.json or test-v1.1.json")
    parser.add_argument("--max_seq_length", default=384, type=int,
                        help="The maximum total input sequence length after WordPiece tokenization. Sequences "
                             "longer than this will be truncated, and sequences shorter than this will be padded.")
    parser.add_argument("--doc_stride", default=128, type=int,
                        help="When splitting up a long document into chunks, how much stride to take between chunks.")
    parser.add_argument("--max_query_length", default=64, type=int,
                        help="The maximum number of tokens for the question. Questions longer than this will "
                             "be truncated to this length.")
    parser.add_argument("--do_train", action='store_true', help="Whether to run training.")
    parser.add_argument("--do_predict", action='store_true', help="Whether to run eval on the dev set.")
    parser.add_argument("--train_batch_size", default=32, type=int, help="Total batch size for training.")
    parser.add_argument("--predict_batch_size", default=8, type=int, help="Total batch size for predictions.")
    parser.add_argument("--learning_rate", default=5e-5, type=float, help="The initial learning rate for Adam.")
    parser.add_argument("--num_train_epochs", default=3.0, type=float,
                        help="Total number of training epochs to perform.")
    parser.add_argument("--warmup_proportion", default=0.1, type=float,
                        help="Proportion of training to perform linear learning rate warmup for. E.g., 0.1 = 10%% "
                             "of training.")
    parser.add_argument("--n_best_size", default=20, type=int,
                        help="The total number of n-best predictions to generate in the nbest_predictions.json "
                             "output file.")
    parser.add_argument("--max_answer_length", default=30, type=int,
                        help="The maximum length of an answer that can be generated. This is needed because the start "
                             "and end predictions are not conditioned on one another.")
    parser.add_argument("--verbose_logging", action='store_true',
                        help="If true, all of the warnings related to data processing will be printed. "
                             "A number of warnings are expected for a normal SQuAD evaluation.")
    parser.add_argument("--no_cuda",
                        action='store_true',
                        help="Whether not to use CUDA when available")
    parser.add_argument('--seed',
                        type=int,
                        default=42,
                        help="random seed for initialization")
    parser.add_argument('--gradient_accumulation_steps',
                        type=int,
                        default=1,
                        help="Number of updates steps to accumulate before performing a backward/update pass.")
    parser.add_argument("--do_lower_case",
                        action='store_true',
                        help="Whether to lower case the input text. True for uncased models, False for cased models.")
    parser.add_argument("--local_rank",
                        type=int,
                        default=-1,
                        help="local_rank for distributed training on gpus")
    parser.add_argument('--fp16',
                        action='store_true',
                        help="Whether to use 16-bit float precision instead of 32-bit")
    parser.add_argument('--loss_scale',
                        type=float, default=0,
                        help="Loss scaling to improve fp16 numeric stability. Only used when fp16 set to True.\n"
                             "0 (default value): dynamic loss scaling.\n"
                             "Positive power of 2: static loss scaling value.\n")
    parser.add_argument('--version_2_with_negative',
                        action='store_true',
                        help='If true, the SQuAD examples contain some that do not have an answer.')
    parser.add_argument('--null_score_diff_threshold',
                        type=float, default=0.0,
                        help="If null_score - best_non_null is greater than the threshold predict null.")
    args = parser.parse_args()

    if args.local_rank == -1 or args.no_cuda:
        device = torch.device("cuda" if torch.cuda.is_available() and not args.no_cuda else "cpu")
        n_gpu = torch.cuda.device_count()
    else:
        torch.cuda.set_device(args.local_rank)
        device = torch.device("cuda", args.local_rank)
        n_gpu = 1
        # Initializes the distributed backend which will take care of synchronizing nodes/GPUs
        torch.distributed.init_process_group(backend='nccl')
    logger.info("device: {} n_gpu: {}, distributed training: {}, 16-bits training: {}".format(
        device, n_gpu, bool(args.local_rank != -1), args.fp16))

    if args.gradient_accumulation_steps < 1:
        raise ValueError("Invalid gradient_accumulation_steps parameter: {}, should be >= 1".format(
                            args.gradient_accumulation_steps))

    args.train_batch_size = args.train_batch_size // args.gradient_accumulation_steps

    random.seed(args.seed)
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)
    if n_gpu > 0:
        torch.cuda.manual_seed_all(args.seed)

    if not args.do_train and not args.do_predict:
        raise ValueError("At least one of `do_train` or `do_predict` must be True.")

    if args.do_train:
        if not args.train_file:
            raise ValueError(
                "If `do_train` is True, then `train_file` must be specified.")
    if args.do_predict:
        if not args.predict_file:
            raise ValueError(
                "If `do_predict` is True, then `predict_file` must be specified.")

    if os.path.exists(args.output_dir) and os.listdir(args.output_dir) and args.do_train:
        raise ValueError("Output directory () already exists and is not empty.")
    if not os.path.exists(args.output_dir):
        os.makedirs(args.output_dir)

    tokenizer = BertTokenizer.from_pretrained(args.bert_model, do_lower_case=args.do_lower_case)

    train_examples = None
    num_train_optimization_steps = None
    if args.do_train:
        train_examples = read_squad_examples(
            input_file=args.train_file, is_training=True, version_2_with_negative=args.version_2_with_negative)
        num_train_optimization_steps = int(
            len(train_examples) / args.train_batch_size / args.gradient_accumulation_steps) * args.num_train_epochs
        if args.local_rank != -1:
            num_train_optimization_steps = num_train_optimization_steps // torch.distributed.get_world_size()

    # Prepare model
    model = BertForQuestionAnsweringNew.from_pretrained(args.bert_model,
                cache_dir=os.path.join(str(PYTORCH_PRETRAINED_BERT_CACHE), 'distributed_{}'.format(args.local_rank)))

    if args.fp16:
        model.half()
    model.to(device)
    if args.local_rank != -1:
        try:
            from apex.parallel import DistributedDataParallel as DDP
        except ImportError:
            raise ImportError("Please install apex from https://www.github.com/nvidia/apex to use distributed and fp16 training.")

        model = DDP(model)
    elif n_gpu > 1:
        model = torch.nn.DataParallel(model)

    # Prepare optimizer
    param_optimizer = list(model.named_parameters())

    # Hack to remove the pooler, which is not used here and would otherwise
    # produce None grads that break apex.
    param_optimizer = [n for n in param_optimizer if 'pooler' not in n[0]]

    no_decay = ['bias', 'LayerNorm.bias', 'LayerNorm.weight']
    optimizer_grouped_parameters = [
        {'params': [p for n, p in param_optimizer if not any(nd in n for nd in no_decay)], 'weight_decay': 0.01},
        {'params': [p for n, p in param_optimizer if any(nd in n for nd in no_decay)], 'weight_decay': 0.0}
        ]

    if args.fp16:
        try:
            from apex.optimizers import FP16_Optimizer
            from apex.optimizers import FusedAdam
        except ImportError:
            raise ImportError("Please install apex from https://www.github.com/nvidia/apex to use distributed and fp16 training.")

        optimizer = FusedAdam(optimizer_grouped_parameters,
                              lr=args.learning_rate,
                              bias_correction=False,
                              max_grad_norm=1.0)
        if args.loss_scale == 0:
            optimizer = FP16_Optimizer(optimizer, dynamic_loss_scale=True)
        else:
            optimizer = FP16_Optimizer(optimizer, static_loss_scale=args.loss_scale)
    else:
        optimizer = BertAdam(optimizer_grouped_parameters,
                             lr=args.learning_rate,
                             warmup=args.warmup_proportion,
                             t_total=num_train_optimization_steps)

    global_step = 0
    if args.do_train:
        cached_train_features_file = args.train_file+'_{0}_{1}_{2}_{3}'.format(
            list(filter(None, args.bert_model.split('/'))).pop(), str(args.max_seq_length), str(args.doc_stride), str(args.max_query_length))
        train_features = None
        try:
            with open(cached_train_features_file, "rb") as reader:
                train_features = pickle.load(reader)
        except Exception:  # no cached features found; build them from scratch
            train_features = convert_examples_to_features(
                examples=train_examples,
                tokenizer=tokenizer,
                max_seq_length=args.max_seq_length,
                doc_stride=args.doc_stride,
                max_query_length=args.max_query_length,
                is_training=True)
            if args.local_rank == -1 or torch.distributed.get_rank() == 0:
                logger.info("  Saving train features into cached file %s", cached_train_features_file)
                with open(cached_train_features_file, "wb") as writer:
                    pickle.dump(train_features, writer)
        logger.info("***** Running training *****")
        logger.info("  Num orig examples = %d", len(train_examples))
        logger.info("  Num split examples = %d", len(train_features))
        logger.info("  Batch size = %d", args.train_batch_size)
        logger.info("  Num steps = %d", num_train_optimization_steps)
        all_input_ids = torch.tensor([f.input_ids for f in train_features], dtype=torch.long)
        all_input_mask = torch.tensor([f.input_mask for f in train_features], dtype=torch.long)
        all_segment_ids = torch.tensor([f.segment_ids for f in train_features], dtype=torch.long)
        all_start_positions = torch.tensor([f.start_position for f in train_features], dtype=torch.long)
        all_end_positions = torch.tensor([f.end_position for f in train_features], dtype=torch.long)
        train_data = TensorDataset(all_input_ids, all_input_mask, all_segment_ids,
                                   all_start_positions, all_end_positions)
        if args.local_rank == -1:
            train_sampler = RandomSampler(train_data)
        else:
            train_sampler = DistributedSampler(train_data)
        train_dataloader = DataLoader(train_data, sampler=train_sampler, batch_size=args.train_batch_size)

        model.train()
        for _ in trange(int(args.num_train_epochs), desc="Epoch"):
            for step, batch in enumerate(tqdm(train_dataloader, desc="Iteration")):
                if n_gpu == 1:
                    batch = tuple(t.to(device) for t in batch)  # multi-gpu does scattering itself
                input_ids, input_mask, segment_ids, start_positions, end_positions = batch
                loss = model(input_ids, segment_ids, input_mask, start_positions, end_positions)
                if n_gpu > 1:
                    loss = loss.mean() # mean() to average on multi-gpu.
                if args.gradient_accumulation_steps > 1:
                    loss = loss / args.gradient_accumulation_steps

                if args.fp16:
                    optimizer.backward(loss)
                else:
                    loss.backward()
                if (step + 1) % args.gradient_accumulation_steps == 0:
                    if args.fp16:
                        # modify learning rate with special warm up BERT uses
                        # if args.fp16 is False, BertAdam is used and handles this automatically
                        lr_this_step = args.learning_rate * warmup_linear(global_step/num_train_optimization_steps, args.warmup_proportion)
                        for param_group in optimizer.param_groups:
                            param_group['lr'] = lr_this_step
                    optimizer.step()
                    optimizer.zero_grad()
                    global_step += 1

    if args.do_train:
        # Save a trained model and the associated configuration
        model_to_save = model.module if hasattr(model, 'module') else model  # Only save the model itself
        output_model_file = os.path.join(args.output_dir, WEIGHTS_NAME)
        torch.save(model_to_save.state_dict(), output_model_file)
        output_config_file = os.path.join(args.output_dir, CONFIG_NAME)
        with open(output_config_file, 'w') as f:
            f.write(model_to_save.config.to_json_string())

        # Load a trained model and config that you have fine-tuned
        config = BertConfig(output_config_file)
        model = BertForQuestionAnsweringNew(config)
        model.load_state_dict(torch.load(output_model_file))
    else:
        model = BertForQuestionAnsweringNew.from_pretrained(args.bert_model)

    model.to(device)

    if args.do_predict and (args.local_rank == -1 or torch.distributed.get_rank() == 0):
        eval_examples = read_squad_examples(
            input_file=args.predict_file, is_training=False, version_2_with_negative=args.version_2_with_negative)
        eval_features = convert_examples_to_features(
            examples=eval_examples,
            tokenizer=tokenizer,
            max_seq_length=args.max_seq_length,
            doc_stride=args.doc_stride,
            max_query_length=args.max_query_length,
            is_training=False)

        logger.info("***** Running predictions *****")
        logger.info("  Num orig examples = %d", len(eval_examples))
        logger.info("  Num split examples = %d", len(eval_features))
        logger.info("  Batch size = %d", args.predict_batch_size)

        all_input_ids = torch.tensor([f.input_ids for f in eval_features], dtype=torch.long)
        all_input_mask = torch.tensor([f.input_mask for f in eval_features], dtype=torch.long)
        all_segment_ids = torch.tensor([f.segment_ids for f in eval_features], dtype=torch.long)
        all_example_index = torch.arange(all_input_ids.size(0), dtype=torch.long)
        eval_data = TensorDataset(all_input_ids, all_input_mask, all_segment_ids, all_example_index)
        # Run prediction for full data
        eval_sampler = SequentialSampler(eval_data)
        eval_dataloader = DataLoader(eval_data, sampler=eval_sampler, batch_size=args.predict_batch_size)

        model.eval()
        all_results = []
        logger.info("Start evaluating")
        for input_ids, input_mask, segment_ids, example_indices in tqdm(eval_dataloader, desc="Evaluating"):
            if len(all_results) % 1000 == 0:
                logger.info("Processing example: %d" % (len(all_results)))
            input_ids = input_ids.to(device)
            input_mask = input_mask.to(device)
            segment_ids = segment_ids.to(device)
            with torch.no_grad():
                batch_start_logits, batch_end_logits = model(input_ids, segment_ids, input_mask)
            for i, example_index in enumerate(example_indices):
                start_logits = batch_start_logits[i].detach().cpu().tolist()
                end_logits = batch_end_logits[i].detach().cpu().tolist()
                eval_feature = eval_features[example_index.item()]
                unique_id = int(eval_feature.unique_id)
                all_results.append(RawResult(unique_id=unique_id,
                                             start_logits=start_logits,
                                             end_logits=end_logits))
        output_prediction_file = os.path.join(args.output_dir, "predictions.json")
        output_nbest_file = os.path.join(args.output_dir, "nbest_predictions.json")
        output_null_log_odds_file = os.path.join(args.output_dir, "null_odds.json")
        write_predictions(eval_examples, eval_features, all_results,
                          args.n_best_size, args.max_answer_length,
                          args.do_lower_case, output_prediction_file,
                          output_nbest_file, output_null_log_odds_file, args.verbose_logging,
                          args.version_2_with_negative, args.null_score_diff_threshold)
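
        # The predictions.json and nbest_predictions.json files written above follow
        # the official SQuAD output format, so they can be scored with the standard
        # evaluation script, e.g. (paths are illustrative):
        #   python evaluate-v1.1.py dev-v1.1.json <output_dir>/predictions.json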
Example #29
0
class ModelAndLoss(nn.Module):
    def __init__(self,
                 arch,
                 loss,
                 pretrained_weights=None,
                 cuda=True,
                 fp16=False,
                 width=1.0,
                 n_struct_layers=0,
                 struct='D',
                 softmax_struct='D',
                 sm_pooling=1,
                 groups=8,
                 shuffle='P'):
        super(ModelAndLoss, self).__init__()
        self.arch = arch

        print("=> creating model '{}'".format(arch))
        # model = models.build_resnet(arch[0], arch[1])
        if arch == 'mobilenetv1':
            model = MobileNet(width_mult=width,
                              structure=[struct] * n_struct_layers,
                              softmax_structure=softmax_struct,
                              sm_pooling=sm_pooling)
            # if args.distilled_param_path:
            #     model.load_state_dict(model.mixed_model_state_dict(args.full_model_path, args.distilled_param_path))
        elif arch == 'shufflenetv1':
            model = ShuffleNet(width_mult=width,
                               groups=groups,
                               shuffle=shuffle)
        else:
            model = models.__dict__[arch]()
        if pretrained_weights is not None:
            print("=> using pre-trained model from a file '{}'".format(arch))
            model.load_state_dict(pretrained_weights)

        if cuda:
            model = model.cuda()
        if fp16:
            model = network_to_half(model)

        # define loss function (criterion) and optimizer
        criterion = loss()

        if cuda:
            criterion = criterion.cuda()

        self.model = model
        self.loss = criterion

    def forward(self, data, target):
        output = self.model(data)
        if hasattr(self, '_teacher_model'):
            with torch.no_grad():
                teacher_output = self._teacher_model(data)
            loss = self.loss(output, teacher_output, target)
        else:
            loss = self.loss(output, target)

        return loss, output

    def distributed(self):
        self.model = DDP(self.model)

    def load_model_state(self, state):
        if state is not None:
            self.model.load_state_dict(state)
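
# Usage sketch (with hypothetical `images` and `target` tensors): ModelAndLoss
# bundles the network and criterion so one forward call returns both the loss and
# the logits, keeping the DDP wrapping in `distributed()` confined to the model.
#   model_and_loss = ModelAndLoss('mobilenetv1', nn.CrossEntropyLoss, cuda=True)
#   loss, output = model_and_loss(images, target)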
Example #30
0
def main():
    parser = argparse.ArgumentParser()
    #drive.mount('/content/gdrive')
    swagDir = './data'
    cacheDir = './cache/'
    saveDir = './save/cache/'


    ## Required parameters
    parser.add_argument("--data_dir",
                        default=swagDir,
                        type=str,
                        #required=True,
                        help="The input data dir. Should contain the .csv files (or other data files) for the task.")
    parser.add_argument("--bert_model", default="bert-base-uncased", type=str,
                        #required=True,
                        help="Bert pre-trained model selected in the list: bert-base-uncased, "
                        "bert-large-uncased, bert-base-cased, bert-large-cased, bert-base-multilingual-uncased, "
                        "bert-base-multilingual-cased, bert-base-chinese.")
    parser.add_argument("--output_dir",
                        default=saveDir,
                        type=str,
                        #required=True,
                        help="The output directory where the model checkpoints will be written.")

    ## Other parameters
    parser.add_argument("--max_seq_length",
                        default=100,
                        type=int,
                        help="The maximum total input sequence length after WordPiece tokenization. \n"
                             "Sequences longer than this will be truncated, and sequences shorter \n"
                             "than this will be padded.")
    parser.add_argument("--do_train",
                        default = True,
                        action='store_true',
                        help="Whether to run training.")
    parser.add_argument("--do_eval",
                        default = True,
                        action='store_true',
                        help="Whether to run eval on the dev set.")
    parser.add_argument("--do_lower_case",
                        action='store_true',
                        help="Set this flag if you are using an uncased model.")
    parser.add_argument("--train_batch_size",
                        default=8,
                        type=int,
                        help="Total batch size for training.")
    parser.add_argument("--eval_batch_size",
                        default=8,
                        type=int,
                        help="Total batch size for eval.")
    parser.add_argument("--learning_rate",
                        default=5e-5,
                        type=float,
                        help="The initial learning rate for Adam.")
    parser.add_argument("--num_train_epochs",
                        default=3.0,
                        type=float,
                        help="Total number of training epochs to perform.")
    parser.add_argument("--warmup_proportion",
                        default=0.1,
                        type=float,
                        help="Proportion of training to perform linear learning rate warmup for. "
                             "E.g., 0.1 = 10%% of training.")
    parser.add_argument("--no_cuda",
                        action='store_true',
                        help="Whether not to use CUDA when available")
    parser.add_argument("--local_rank",
                        type=int,
                        default=-1,
                        help="local_rank for distributed training on gpus")
    parser.add_argument('--seed',
                        type=int,
                        default=42,
                        help="random seed for initialization")
    parser.add_argument('--gradient_accumulation_steps',
                        type=int,
                        default=1,
                        help="Number of updates steps to accumulate before performing a backward/update pass.")
    parser.add_argument('--fp16',
                        action='store_true',
                        help="Whether to use 16-bit float precision instead of 32-bit")
    parser.add_argument('--loss_scale',
                        type=float, default=0,
                        help="Loss scaling to improve fp16 numeric stability. Only used when fp16 set to True.\n"
                             "0 (default value): dynamic loss scaling.\n"
                             "Positive power of 2: static loss scaling value.\n")

    args = parser.parse_args()
    device = torch.device("cuda")
    n_gpu = torch.cuda.device_count()
    """
    if args.local_rank == -1 or args.no_cuda:
        device = torch.device("cuda" if torch.cuda.is_available() and not args.no_cuda else "cpu")
        n_gpu = torch.cuda.device_count()
    else:
        torch.cuda.set_device(args.local_rank)
        device = torch.device("cuda", args.local_rank)
        n_gpu = 1
        # Initializes the distributed backend which will take care of synchronizing nodes/GPUs
        torch.distributed.init_process_group(backend='nccl')
    """
    logger.info("device: {} n_gpu: {}, distributed training: {}, 16-bits training: {}".format(
        device, n_gpu, bool(args.local_rank != -1), args.fp16))

    if args.gradient_accumulation_steps < 1:
        raise ValueError("Invalid gradient_accumulation_steps parameter: {}, should be >= 1".format(
                            args.gradient_accumulation_steps))

    args.train_batch_size = args.train_batch_size // args.gradient_accumulation_steps

    random.seed(args.seed)
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)
    if n_gpu > 0:
        torch.cuda.manual_seed_all(args.seed)

    if not args.do_train and not args.do_eval:
        raise ValueError("At least one of `do_train` or `do_eval` must be True.")

    if os.path.exists(args.output_dir) and os.listdir(args.output_dir):
        raise ValueError("Output directory ({}) already exists and is not empty.".format(args.output_dir))
    if not os.path.exists(args.output_dir):
        os.makedirs(args.output_dir)

    tokenizer = BertTokenizer.from_pretrained(args.bert_model, do_lower_case=args.do_lower_case)

    train_examples = None
    num_train_optimization_steps = None
    if args.do_train:
        train_examples = read_swag_examples(os.path.join(args.data_dir, 'train.csv'), is_training = True)
        num_train_optimization_steps = int(
            len(train_examples) / args.train_batch_size / args.gradient_accumulation_steps) * args.num_train_epochs
        if args.local_rank != -1:
            num_train_optimization_steps = num_train_optimization_steps // torch.distributed.get_world_size()

    # Prepare model
    model = BertForMultipleChoice.from_pretrained(args.bert_model,
        cache_dir=os.path.join(str(PYTORCH_PRETRAINED_BERT_CACHE), 'distributed_{}'.format(args.local_rank)),
        num_choices=4)
    if args.fp16:
        model.half()
    model.to(device)
    if args.local_rank != -1:
        try:
            from apex.parallel import DistributedDataParallel as DDP
        except ImportError:
            raise ImportError("Please install apex from https://www.github.com/nvidia/apex to use distributed and fp16 training.")

        model = DDP(model)
    elif n_gpu > 1:
        model = torch.nn.DataParallel(model)

    # Prepare optimizer
    param_optimizer = list(model.named_parameters())

    # Hack to remove the pooler, which is not used here and would otherwise
    # produce None grads that break apex.
    param_optimizer = [n for n in param_optimizer if 'pooler' not in n[0]]

    no_decay = ['bias', 'LayerNorm.bias', 'LayerNorm.weight']
    optimizer_grouped_parameters = [
        {'params': [p for n, p in param_optimizer if not any(nd in n for nd in no_decay)], 'weight_decay': 0.01},
        {'params': [p for n, p in param_optimizer if any(nd in n for nd in no_decay)], 'weight_decay': 0.0}
        ]
    if args.fp16:
        try:
            from apex.optimizers import FP16_Optimizer
            from apex.optimizers import FusedAdam
        except ImportError:
            raise ImportError("Please install apex from https://www.github.com/nvidia/apex to use distributed and fp16 training.")

        optimizer = FusedAdam(optimizer_grouped_parameters,
                              lr=args.learning_rate,
                              bias_correction=False,
                              max_grad_norm=1.0)
        if args.loss_scale == 0:
            optimizer = FP16_Optimizer(optimizer, dynamic_loss_scale=True)
        else:
            optimizer = FP16_Optimizer(optimizer, static_loss_scale=args.loss_scale)
    else:
        optimizer = BertAdam(optimizer_grouped_parameters,
                             lr=args.learning_rate,
                             warmup=args.warmup_proportion,
                             t_total=num_train_optimization_steps)

    global_step = 0
    if args.do_train:
        train_features = convert_examples_to_features(
            train_examples, tokenizer, args.max_seq_length, True)
        logger.info("***** Running training *****")
        logger.info("  Num examples = %d", len(train_examples))
        logger.info("  Batch size = %d", args.train_batch_size)
        logger.info("  Num steps = %d", num_train_optimization_steps)
        all_input_ids = torch.tensor(select_field(train_features, 'input_ids'), dtype=torch.long)
        all_input_mask = torch.tensor(select_field(train_features, 'input_mask'), dtype=torch.long)
        all_segment_ids = torch.tensor(select_field(train_features, 'segment_ids'), dtype=torch.long)
        all_label = torch.tensor([f.label for f in train_features], dtype=torch.long)
        train_data = TensorDataset(all_input_ids, all_input_mask, all_segment_ids, all_label)
        if args.local_rank == -1:
            train_sampler = RandomSampler(train_data)
        else:
            train_sampler = DistributedSampler(train_data)
        train_dataloader = DataLoader(train_data, sampler=train_sampler, batch_size=args.train_batch_size)

        model.train()
        for _ in trange(int(args.num_train_epochs), desc="Epoch"):
            tr_loss = 0
            nb_tr_examples, nb_tr_steps = 0, 0
            #for step, batch in enumerate(tqdm(train_dataloader, desc="Iteration")):
            for step, batch in enumerate(tqdm(train_dataloader, desc="Iteration")):
                batch = tuple(t.to(device) for t in batch)
                input_ids, input_mask, segment_ids, label_ids = batch
                loss = model(input_ids, segment_ids, input_mask, label_ids)
                if n_gpu > 1:
                    loss = loss.mean() # mean() to average on multi-gpu.
                if args.fp16 and args.loss_scale != 1.0:
                    # rescale loss for fp16 training
                    # see https://docs.nvidia.com/deeplearning/sdk/mixed-precision-training/index.html
                    loss = loss * args.loss_scale
                if args.gradient_accumulation_steps > 1:
                    loss = loss / args.gradient_accumulation_steps
                tr_loss += loss.item()
                nb_tr_examples += input_ids.size(0)
                nb_tr_steps += 1

                if args.fp16:
                    optimizer.backward(loss)
                else:
                    loss.backward()
                if (step + 1) % args.gradient_accumulation_steps == 0:
                    if args.fp16:
                        # modify learning rate with special warm up BERT uses
                        # if args.fp16 is False, BertAdam is used that handles this automatically
                        lr_this_step = args.learning_rate * warmup_linear(
                            global_step / num_train_optimization_steps,
                            args.warmup_proportion)
                        for param_group in optimizer.param_groups:
                            param_group['lr'] = lr_this_step
                    optimizer.step()
                    optimizer.zero_grad()
                    global_step += 1


    if args.do_train:
        # Save a trained model and the associated configuration
        model_to_save = model.module if hasattr(model, 'module') else model  # Only save the model itself
        output_model_file = os.path.join(args.output_dir, WEIGHTS_NAME)
        torch.save(model_to_save.state_dict(), output_model_file)
        output_config_file = os.path.join(args.output_dir, CONFIG_NAME)
        with open(output_config_file, 'w') as f:
            f.write(model_to_save.config.to_json_string())

        # Load a trained model and config that you have fine-tuned
        config = BertConfig(output_config_file)
        model = BertForMultipleChoice(config, num_choices=4)
        model.load_state_dict(torch.load(output_model_file))
    else:
        model = BertForMultipleChoice.from_pretrained(args.bert_model, num_choices=4)
    model.to(device)


    if args.do_eval and (args.local_rank == -1 or torch.distributed.get_rank() == 0):
        eval_examples = read_swag_examples(os.path.join(args.data_dir, 'val.csv'), is_training = True)
        eval_features = convert_examples_to_features(
            eval_examples, tokenizer, args.max_seq_length, True)
        logger.info("***** Running evaluation *****")
        logger.info("  Num examples = %d", len(eval_examples))
        logger.info("  Batch size = %d", args.eval_batch_size)
        all_input_ids = torch.tensor(select_field(eval_features, 'input_ids'), dtype=torch.long)
        all_input_mask = torch.tensor(select_field(eval_features, 'input_mask'), dtype=torch.long)
        all_segment_ids = torch.tensor(select_field(eval_features, 'segment_ids'), dtype=torch.long)
        all_label = torch.tensor([f.label for f in eval_features], dtype=torch.long)
        eval_data = TensorDataset(all_input_ids, all_input_mask, all_segment_ids, all_label)
        # Run prediction for full data
        eval_sampler = SequentialSampler(eval_data)
        eval_dataloader = DataLoader(eval_data, sampler=eval_sampler, batch_size=args.eval_batch_size)

        model.eval()
        eval_loss, eval_accuracy = 0, 0
        nb_eval_steps, nb_eval_examples = 0, 0
        for input_ids, input_mask, segment_ids, label_ids in tqdm(eval_dataloader, desc="Evaluating"):
            input_ids = input_ids.to(device)
            input_mask = input_mask.to(device)
            segment_ids = segment_ids.to(device)
            label_ids = label_ids.to(device)

            with torch.no_grad():
                tmp_eval_loss = model(input_ids, segment_ids, input_mask, label_ids)
                logits = model(input_ids, segment_ids, input_mask)

            logits = logits.detach().cpu().numpy()
            label_ids = label_ids.to('cpu').numpy()
            tmp_eval_accuracy = accuracy(logits, label_ids)

            eval_loss += tmp_eval_loss.mean().item()
            eval_accuracy += tmp_eval_accuracy

            nb_eval_examples += input_ids.size(0)
            nb_eval_steps += 1

        eval_loss = eval_loss / nb_eval_steps
        eval_accuracy = eval_accuracy / nb_eval_examples

        result = {'eval_loss': eval_loss,
                  'eval_accuracy': eval_accuracy,
                  'global_step': global_step,
                  'loss': tr_loss / nb_tr_steps if args.do_train else None}

        output_eval_file = os.path.join(args.output_dir, "eval_results.txt")
        with open(output_eval_file, "w") as writer:
            logger.info("***** Eval results *****")
            for key in sorted(result.keys()):
                logger.info("  %s = %s", key, str(result[key]))
                writer.write("%s = %s\n" % (key, str(result[key])))
Example #31
0
def main():
    parser = argparse.ArgumentParser()

    ## Required parameters
    parser.add_argument("--train_file",
                        default='../../data/eng-2015.conll',
                        type=str,
                        required=True,
                        help="train file path")
    parser.add_argument("--dev_file",
                        default='../../data/eng-2016.conll',
                        type=str,
                        required=True,
                        help="dev file path")

    parser.add_argument(
        "--bert_model",
        default='bert-base-cased',
        type=str,
        required=True,
        help="Bert pre-trained model selected in the list: bert-base-uncased, "
        "bert-large-uncased, bert-base-cased, bert-large-cased, bert-base-multilingual-uncased, "
        "bert-base-multilingual-cased, bert-base-chinese.")

    parser.add_argument("--finetune_dir",
                        default='NER_BERT',
                        type=str,
                        required=False,
                        help="The output")

    parser.add_argument(
        "--output_dir",
        default='NER_BERT',
        type=str,
        required=True,
        help=
        "The output directory where the model predictions and checkpoints will be written."
    )

    ## Other parameters
    parser.add_argument(
        "--cache_dir",
        default="",
        type=str,
        help=
        "Where do you want to store the pre-trained models downloaded from s3")
    parser.add_argument(
        "--max_seq_length",
        default=128,
        type=int,
        help=
        "The maximum total input sequence length after WordPiece tokenization. \n"
        "Sequences longer than this will be truncated, and sequences shorter \n"
        "than this will be padded.")
    parser.add_argument("--do_train",
                        action='store_true',
                        help="Whether to run training.")
    parser.add_argument("--do_finetune",
                        action='store_true',
                        help="Whether to run finetuning.")
    parser.add_argument("--do_eval",
                        action='store_true',
                        help="Whether to run eval on the dev set.")
    parser.add_argument(
        "--do_lower_case",
        action='store_true',
        help="Set this flag if you are using an uncased model.")
    parser.add_argument("--train_batch_size",
                        default=32,
                        type=int,
                        help="Total batch size for training.")
    parser.add_argument("--eval_batch_size",
                        default=32,
                        type=int,
                        help="Total batch size for eval.")
    parser.add_argument("--learning_rate",
                        default=5e-5,
                        type=float,
                        help="The initial learning rate for Adam.")
    parser.add_argument("--num_train_epochs",
                        default=3.0,
                        type=float,
                        help="Total number of training epochs to perform.")
    parser.add_argument(
        "--warmup_proportion",
        default=0.1,
        type=float,
        help=
        "Proportion of training to perform linear learning rate warmup for. "
        "E.g., 0.1 = 10%% of training.")
    parser.add_argument("--no_cuda",
                        action='store_true',
                        help="Whether not to use CUDA when available")
    parser.add_argument("--local_rank",
                        type=int,
                        default=-1,
                        help="local_rank for distributed training on gpus")
    parser.add_argument('--seed',
                        type=int,
                        default=42,
                        help="random seed for initialization")
    parser.add_argument(
        '--gradient_accumulation_steps',
        type=int,
        default=1,
        help=
        "Number of updates steps to accumulate before performing a backward/update pass."
    )
    parser.add_argument(
        '--fp16',
        action='store_true',
        help="Whether to use 16-bit float precision instead of 32-bit")
    parser.add_argument(
        '--loss_scale',
        type=float,
        default=0,
        help=
        "Loss scaling to improve fp16 numeric stability. Only used when fp16 set to True.\n"
        "0 (default value): dynamic loss scaling.\n"
        "Positive power of 2: static loss scaling value.\n")
    parser.add_argument('--server_ip',
                        type=str,
                        default='',
                        help="Can be used for distant debugging.")
    parser.add_argument('--server_port',
                        type=str,
                        default='',
                        help="Can be used for distant debugging.")

    args = parser.parse_args()
    if args.local_rank == -1 or args.no_cuda:
        device = torch.device("cuda" if torch.cuda.is_available()
                              and not args.no_cuda else "cpu")
        n_gpu = torch.cuda.device_count()
    else:
        torch.cuda.set_device(args.local_rank)
        device = torch.device("cuda", args.local_rank)
        n_gpu = 1
        # Initializes the distributed backend which will take care of synchronizing nodes/GPUs
        torch.distributed.init_process_group(backend='nccl')
    logger.info(
        "device: {} n_gpu: {}, distributed training: {}, 16-bits training: {}".
        format(device, n_gpu, bool(args.local_rank != -1), args.fp16))

    if args.gradient_accumulation_steps < 1:
        raise ValueError(
            "Invalid gradient_accumulation_steps parameter: {}, should be >= 1"
            .format(args.gradient_accumulation_steps))

    args.train_batch_size = args.train_batch_size // args.gradient_accumulation_steps

    random.seed(args.seed)
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)
    if n_gpu > 0:
        torch.cuda.manual_seed_all(args.seed)

    if not args.do_train and not args.do_eval:
        raise ValueError(
            "At least one of `do_train` or `do_eval` must be True.")

    if os.path.exists(args.output_dir) and os.listdir(
            args.output_dir) and args.do_train:
        raise ValueError(
            "Output directory ({}) already exists and is not empty.".format(
                args.output_dir))
    if not os.path.exists(args.output_dir):
        os.makedirs(args.output_dir)
    tokenizer = BertTokenizer.from_pretrained(args.bert_model,
                                              do_lower_case=args.do_lower_case,
                                              do_basic_tokenize=False)
    label_list = get_labels()
    num_labels = len(label_list)
    train_examples = read_ner_example(args.train_file, args.do_lower_case)
    num_train_optimization_steps = None
    if args.do_train:
        #train_examples = processor.get_train_examples(args.data_dir)
        num_train_optimization_steps = int(
            len(train_examples) / args.train_batch_size /
            args.gradient_accumulation_steps) * args.num_train_epochs
        if args.local_rank != -1:
            num_train_optimization_steps = num_train_optimization_steps // torch.distributed.get_world_size()

    # Prepare model
    cache_dir = args.cache_dir if args.cache_dir else os.path.join(
        str(PYTORCH_PRETRAINED_BERT_CACHE), 'distributed_{}'.format(
            args.local_rank))
    model = BertForTokenClassification.from_pretrained(args.bert_model,
                                                       cache_dir=cache_dir,
                                                       num_labels=num_labels)
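    # Under fp16, cast the model weights to half precision; apex's FP16_Optimizer
    # below keeps FP32 master copies of the weights for the parameter update.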
    if args.fp16:
        model.half()
    if args.do_finetune:
        if not os.path.exists(args.finetune_dir) or not os.listdir(
                args.finetune_dir):
            raise ValueError(
                "Finetune directory ({}) is missing or empty.".format(
                    args.finetune_dir))
        finetune_model_file = os.path.join(args.finetune_dir, WEIGHTS_NAME)
        finetune_config_file = os.path.join(args.finetune_dir, CONFIG_NAME)
        config = BertConfig(finetune_config_file)
        #model = BertForTokenClassification(config, num_labels=num_labels)
        model.load_state_dict(torch.load(finetune_model_file))
    model.to(device)
    if args.local_rank != -1:
        try:
            from apex.parallel import DistributedDataParallel as DDP
        except ImportError:
            raise ImportError(
                "Please install apex from https://www.github.com/nvidia/apex to use distributed and fp16 training."
            )

        model = DDP(model)
    elif n_gpu > 1:
        model = torch.nn.DataParallel(model)

    # Prepare optimizer
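    # Follow the BERT recipe: apply weight decay to all parameters except biases
    # and LayerNorm weights.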
    param_optimizer = list(model.named_parameters())
    no_decay = ['bias', 'LayerNorm.bias', 'LayerNorm.weight']
    optimizer_grouped_parameters = [
        {
            'params': [p for n, p in param_optimizer
                       if not any(nd in n for nd in no_decay)],
            'weight_decay': 0.01
        },
        {
            'params': [p for n, p in param_optimizer
                       if any(nd in n for nd in no_decay)],
            'weight_decay': 0.0
        },
    ]
    if args.fp16:
        try:
            from apex.optimizers import FP16_Optimizer
            from apex.optimizers import FusedAdam
        except ImportError:
            raise ImportError(
                "Please install apex from https://www.github.com/nvidia/apex to use distributed and fp16 training."
            )

        optimizer = FusedAdam(optimizer_grouped_parameters,
                              lr=args.learning_rate,
                              bias_correction=False,
                              max_grad_norm=1.0)
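        # loss_scale == 0 selects dynamic loss scaling; any other (positive)
        # value is used as a static loss scale.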
        if args.loss_scale == 0:
            optimizer = FP16_Optimizer(optimizer, dynamic_loss_scale=True)
        else:
            optimizer = FP16_Optimizer(optimizer,
                                       static_loss_scale=args.loss_scale)

    else:
        optimizer = BertAdam(optimizer_grouped_parameters,
                             lr=args.learning_rate,
                             warmup=args.warmup_proportion,
                             t_total=num_train_optimization_steps)
    global_step = 0
    nb_tr_steps = 0
    tr_loss = 0
    if args.do_train:
        train_features = convert_examples_to_features(train_examples,
                                                      label_list,
                                                      args.max_seq_length,
                                                      tokenizer)
        logger.info("***** Running training *****")
        logger.info("  Num examples = %d", len(train_examples))
        logger.info("  Batch size = %d", args.train_batch_size)
        logger.info("  Num steps = %d", num_train_optimization_steps)
        all_input_ids = torch.tensor([f.input_ids for f in train_features],
                                     dtype=torch.long)
        all_input_mask = torch.tensor([f.input_mask for f in train_features],
                                      dtype=torch.long)
        all_segment_ids = torch.tensor([f.segment_ids for f in train_features],
                                       dtype=torch.long)
        all_label_ids = torch.tensor([f.label_id for f in train_features],
                                     dtype=torch.long)
        all_label_masks = torch.tensor([f.label_mask for f in train_features],
                                       dtype=torch.long)
        train_data = TensorDataset(all_input_ids, all_input_mask,
                                   all_segment_ids, all_label_ids,
                                   all_label_masks)
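        # In distributed runs each process sees a disjoint shard of the data via
        # DistributedSampler; single-process runs shuffle with RandomSampler.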
        if args.local_rank == -1:
            train_sampler = RandomSampler(train_data)
        else:
            train_sampler = DistributedSampler(train_data)
        train_dataloader = DataLoader(train_data,
                                      sampler=train_sampler,
                                      batch_size=args.train_batch_size)

        model.train()
        for _ in trange(int(args.num_train_epochs), desc="Epoch"):
            tr_loss = 0
            nb_tr_examples, nb_tr_steps = 0, 0
            for step, batch in enumerate(
                    tqdm(train_dataloader, desc="Iteration")):
                batch = tuple(t.to(device) for t in batch)
                input_ids, input_mask, segment_ids, label_ids, label_masks = batch
                loss = model(input_ids, segment_ids, input_mask, label_ids,
                             label_masks)
                if n_gpu > 1:
                    loss = loss.mean()  # mean() to average on multi-gpu.
                if args.gradient_accumulation_steps > 1:
                    loss = loss / args.gradient_accumulation_steps

                if args.fp16:
                    optimizer.backward(loss)
                else:
                    loss.backward()

                tr_loss += loss.item()
                nb_tr_examples += input_ids.size(0)
                nb_tr_steps += 1
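                # Only step the optimizer once every gradient_accumulation_steps
                # micro-batches.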
                if (step + 1) % args.gradient_accumulation_steps == 0:
                    if args.fp16:
                        # Manually apply the warm-up schedule BERT uses when running with fp16;
                        # when args.fp16 is False, BertAdam handles this automatically.
                        lr_this_step = args.learning_rate * warmup_linear(
                            global_step / num_train_optimization_steps,
                            args.warmup_proportion)
                        for param_group in optimizer.param_groups:
                            param_group['lr'] = lr_this_step
                    optimizer.step()
                    optimizer.zero_grad()
                    global_step += 1
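        # Training finished: save the fine-tuned weights and config so they can
        # be reloaded below (and reused by later evaluation-only runs).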
        model_to_save = model.module if hasattr(
            model, 'module') else model  # Only save the model itself
        output_model_file = os.path.join(args.output_dir, WEIGHTS_NAME)
        torch.save(model_to_save.state_dict(), output_model_file)
        output_config_file = os.path.join(args.output_dir, CONFIG_NAME)
        with open(output_config_file, 'w') as f:
            f.write(model_to_save.config.to_json_string())

        # Load a trained model and config that you have fine-tuned
        config = BertConfig(output_config_file)
        model = BertForTokenClassification(config, num_labels=num_labels)
        model.load_state_dict(torch.load(output_model_file))
    else:
        output_model_file = os.path.join(args.output_dir, WEIGHTS_NAME)
        output_config_file = os.path.join(args.output_dir, CONFIG_NAME)
        config = BertConfig(output_config_file)
        model = BertForTokenClassification(config, num_labels=num_labels)
        model.load_state_dict(torch.load(output_model_file))
        #model = BertForTokenClassification.from_pretrained(args.bert_model, num_labels=num_labels)
    model.to(device)
    if args.do_eval and (args.local_rank == -1
                         or torch.distributed.get_rank() == 0):
        eval_examples = read_ner_example(args.dev_file, args.do_lower_case)
        eval_features = convert_examples_to_features(eval_examples, label_list,
                                                     args.max_seq_length,
                                                     tokenizer)
        logger.info("***** Running evaluation *****")
        logger.info("  Num examples = %d", len(eval_examples))
        logger.info("  Batch size = %d", args.eval_batch_size)

        all_input_ids = torch.tensor([f.input_ids for f in eval_features],
                                     dtype=torch.long)
        all_input_mask = torch.tensor([f.input_mask for f in eval_features],
                                      dtype=torch.long)
        all_segment_ids = torch.tensor([f.segment_ids for f in eval_features],
                                       dtype=torch.long)
        all_label_ids = torch.tensor([f.label_id for f in eval_features],
                                     dtype=torch.long)
        all_label_masks = torch.tensor([f.label_mask for f in eval_features],
                                       dtype=torch.long)
        eval_data = TensorDataset(all_input_ids, all_input_mask,
                                  all_segment_ids, all_label_ids,
                                  all_label_masks)
        # Run prediction for full data
        eval_sampler = SequentialSampler(eval_data)
        eval_dataloader = DataLoader(eval_data,
                                     sampler=eval_sampler,
                                     batch_size=args.eval_batch_size)

        model.eval()
        eval_loss, eval_accuracy = 0, 0
        nb_eval_steps, nb_eval_examples = 0, 0
        pred_list = []
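        # Note: label_list is re-bound below to collect the gold label ids seen
        # during evaluation; the label names are re-fetched via get_labels() later.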
        label_list = []
        for input_ids, input_mask, segment_ids, label_ids, label_masks in tqdm(
                eval_dataloader, desc="Evaluating"):
            input_ids = input_ids.to(device)
            input_mask = input_mask.to(device)
            segment_ids = segment_ids.to(device)
            label_ids = label_ids.to(device)
            label_masks = label_masks.to(device)
            with torch.no_grad():
                tmp_eval_loss = model(input_ids, segment_ids, input_mask,
                                      label_ids, label_masks)
                logits = model(input_ids, segment_ids, input_mask)
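            # Metrics are computed only on positions flagged by label_mask
            # (padding and, typically, non-initial word-piece tokens are masked out).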
            active_loss = label_masks.view(-1) == 1
            active_logits = logits.view(-1, num_labels)[active_loss]
            #print(active_logits.shape)
            active_labels = label_ids.view(-1)[active_loss]
            active_logits = active_logits.detach().cpu().numpy()
            #print(active_logits.shape)
            active_labels = active_labels.to('cpu').numpy()
            active_preds = np.argmax(active_logits, axis=1)
            #print(active_labels.shape, active_preds.shape)
            #tmp_eval_accuracy = accuracy(logits, label_ids, label_masks)

            #eval_loss += tmp_eval_loss.mean().item()
            #eval_accuracy += tmp_eval_accuracy
            pred_list.extend(active_preds)
            label_list.extend(active_labels)
            #print(active_labels.shape)
            nb_eval_examples += active_labels.shape[0]
            nb_eval_steps += 1

        #eval_loss = eval_loss / nb_eval_steps
        #eval_accuracy = eval_accuracy / nb_eval_examples
        loss = tr_loss / nb_tr_steps if args.do_train else None
        eval_f1_micro = f1_score(label_list, pred_list, average='micro')
        eval_f1_none = f1_score(label_list, pred_list, average=None)
        result = {
            'eval_f1_micro': eval_f1_micro,
            'eval_f1_none': eval_f1_none,
            'global_step': global_step,
            'loss': loss
        }

        output_eval_file = os.path.join(args.output_dir, "eval_results.txt")
        with open(output_eval_file, "w") as writer:
            logger.info("***** Eval results *****")
            for key in sorted(result.keys()):
                logger.info("  %s = %s", key, str(result[key]))
                writer.write("%s = %s\n" % (key, str(result[key])))
        output_pred_file = os.path.join(args.output_dir, "pred_results.conll")
        label_map = get_labels()
        logger.info("  Num gold labels = %d, num predictions = %d",
                    len(label_list), len(pred_list))
        with open(output_pred_file, 'w') as f, open(args.dev_file) as dev_f:
            idx = 1
            pred_iter = iter(zip(label_list, pred_list))
            for dl in dev_f:
                # Blank lines in the CoNLL file mark sentence boundaries and carry
                # no prediction, so they must not consume a (gold, pred) pair.
                if not dl.strip():
                    f.write('\n')
                    idx = 1
                    continue
                pair = next(pred_iter, None)
                if pair is None:
                    # Predictions can run out early when long sentences were
                    # truncated to max_seq_length.
                    break
                l, p = pair
                f.write(' '.join((str(idx), label_map[l], label_map[p])) + '\n')
                idx += 1