Пример #1
0
def main():
    """Build the model, optimizer and data loaders, then train and validate.

    All options are read from the module-level ``args`` and ``cfg`` objects,
    which must already exist when this function is called (it never parses
    the command line itself).  Only the process with ``args.local_rank == 0``
    prints progress information and writes checkpoints.

    Raises:
        NotImplementedError: for an unsupported ``args.arch`` or
            ``cfg.dataset``.
    """
    global best_prec1, args

    args.distributed = False
    if 'WORLD_SIZE' in os.environ:
        args.distributed = int(os.environ['WORLD_SIZE']) > 1

    args.gpu = 0
    args.world_size = 1

    if args.distributed:
        args.gpu = args.local_rank
        torch.cuda.set_device(args.gpu)
        torch.distributed.init_process_group(backend='nccl',
                                             init_method='env://')
        args.world_size = torch.distributed.get_world_size()

    if args.fp16:
        assert torch.backends.cudnn.enabled, "fp16 requires cudnn backend to be enabled."
    if args.static_loss_scale != 1.0:
        if not args.fp16:
            print(
                "Warning:  if --fp16 is not used, static_loss_scale will be ignored."
            )

    # create model
    # Only the AOGNet branch reads a checkpoint from disk; the other
    # pre-trained branches pass pretrained=True to the model factory.  Keep an
    # explicit sentinel so the deferred load_state_dict below is skipped for
    # them instead of failing with a NameError.
    checkpoint = None
    if args.pretrained:
        if args.local_rank == 0:
            print("=> using pre-trained model '{}'".format(args.arch))
        if args.arch.startswith('aognet'):
            cfg.merge_from_file(os.path.join(args.save_dir, 'config.yaml'))

            model = aognet_m() if args.arch == 'aognet_m' else aognet_s()
            # The weights are applied further down, after the model has been
            # moved to the GPU and wrapped.
            checkpoint = torch.load(
                os.path.join(args.save_dir, 'model_best.pth.tar'))
        elif args.arch.startswith('resnet'):
            model = resnets.__dict__[args.arch](pretrained=True)
        elif args.arch.startswith('mobilenet'):
            model = mobilenets.__dict__[args.arch](pretrained=True)
        else:
            raise NotImplementedError("Unkown network arch.")
    else:
        if args.local_rank == 0:
            print("=> creating {}".format(args.arch))
        # The config file overrides the training hyper-parameters in args.
        cfg.merge_from_file(args.cfg)
        args.batch_size = cfg.batch_size
        args.lr = cfg.lr
        args.momentum = cfg.momentum
        args.weight_decay = cfg.wd
        args.nesterov = cfg.nesterov
        args.epochs = cfg.num_epoch
        if args.arch.startswith('aognet'):
            model = aognet_m() if args.arch == 'aognet_m' else aognet_s()
        elif args.arch.startswith('resnet'):
            model = resnets.__dict__[args.arch](
                zero_init_residual=cfg.norm_zero_gamma_init,
                num_classes=cfg.num_classes,
                replace_stride_with_dilation=cfg.resnet.
                replace_stride_with_dilation,
                dataset=cfg.dataset,
                base_inplanes=cfg.resnet.base_inplanes,
                imagenet_head7x7=cfg.stem.imagenet_head7x7,
                stem_kernel_size=cfg.stem.stem_kernel_size,
                stem_stride=cfg.stem.stem_stride,
                norm_name=cfg.norm_name,
                norm_groups=cfg.norm_groups,
                norm_k=cfg.norm_k,
                norm_attention_mode=cfg.norm_attention_mode,
                norm_all_mix=cfg.norm_all_mix,
                extra_norm_ac=cfg.resnet.extra_norm_ac,
                replace_stride_with_avgpool=cfg.resnet.
                replace_stride_with_avgpool)
        elif args.arch.startswith('MobileNetV3'):
            model = mobilenetsv3.__dict__[args.arch](
                norm_name=cfg.norm_name,
                norm_groups=cfg.norm_groups,
                norm_k=cfg.norm_k,
                norm_attention_mode=cfg.norm_attention_mode,
                rm_se=cfg.mobilenet.rm_se,
                use_mn_in_se=cfg.mobilenet.use_mn_in_se)
        elif args.arch.startswith('mobilenet'):
            model = mobilenets.__dict__[args.arch](
                norm_name=cfg.norm_name,
                norm_groups=cfg.norm_groups,
                norm_k=cfg.norm_k,
                norm_attention_mode=cfg.norm_attention_mode)
        elif args.arch.startswith('densenet'):
            model = densenets.__dict__[args.arch](
                num_classes=cfg.num_classes,
                imagenet_head7x7=cfg.stem.imagenet_head7x7,
                norm_name=cfg.norm_name,
                norm_groups=cfg.norm_groups,
                norm_k=cfg.norm_k,
                norm_attention_mode=cfg.norm_attention_mode)
        else:
            raise NotImplementedError("Unkown network arch.")

    # Report model complexity once, on rank 0, using a throw-away copy of the
    # model so that profiling cannot touch the instance that gets trained.
    if args.local_rank == 0:
        if cfg.dataset.startswith('cifar'):
            H, W = 32, 32
        elif cfg.dataset.startswith('imagenet'):
            H, W = 224, 224
        else:
            raise NotImplementedError("Unknown dataset")
        flops, params = thop_profile(copy.deepcopy(model),
                                     input_size=(1, 3, H, W))
        print('=> FLOPs: {:.6f}G, Params: {:.6f}M'.format(
            flops / 1e9, params / 1e6))
        print('=> Params (double-check): %.6fM' %
              (sum(p.numel() for p in model.parameters()) / 1e6))

    if args.sync_bn:
        import apex
        if args.local_rank == 0:
            print("using apex synced BN")
        model = apex.parallel.convert_syncbn_model(model)

    model = model.cuda()
    if args.fp16:
        model = FP16Model(model)
    if args.distributed:
        # delay_allreduce delays all communication to the end of the backward
        # pass (the default would overlap it with computation).
        model = DDP(model, delay_allreduce=True)

    if args.pretrained and checkpoint is not None:
        model.load_state_dict(checkpoint['state_dict'])

    # Scale learning rate based on global batch size
    args.lr = args.lr * float(
        args.batch_size *
        args.world_size) / cfg.lr_scale_factor  #TODO: control the maximum?

    if args.remove_norm_weight_decay:
        if args.local_rank == 0:
            print("=> ! Weight decay NOT applied to FeatNorm parameters ")
        norm_types = (nn.BatchNorm2d, nn.GroupNorm, MixtureBatchNorm2d,
                      MixtureGroupNorm)
        # Ordered lists (not sets) keep the optimizer's parameter order
        # identical from run to run; optimizer state dicts address parameters
        # by position, so a set's arbitrary order could mis-assign momentum
        # buffers when a checkpoint is resumed.  `seen` de-duplicates
        # parameters shared between modules.
        norm_params = []  #TODO: need to check this via experiments
        rest_params = []
        seen = set()
        for m in model.modules():
            bucket = norm_params if isinstance(m, norm_types) else rest_params
            # recurse=False: only the parameters owned directly by `m`.
            for param in m.parameters(False):
                if id(param) not in seen:
                    seen.add(id(param))
                    bucket.append(param)

        optimizer = torch.optim.SGD([{
            'params': norm_params,
            'weight_decay': 0.0
        }, {
            'params': rest_params
        }],
                                    args.lr,
                                    momentum=args.momentum,
                                    weight_decay=args.weight_decay,
                                    nesterov=args.nesterov)
    else:
        if args.local_rank == 0:
            print("=> ! Weight decay applied to FeatNorm parameters ")
        optimizer = torch.optim.SGD(model.parameters(),
                                    args.lr,
                                    momentum=args.momentum,
                                    weight_decay=args.weight_decay,
                                    nesterov=args.nesterov)

    if args.fp16:
        optimizer = FP16_Optimizer(optimizer,
                                   static_loss_scale=args.static_loss_scale,
                                   dynamic_loss_scale=args.dynamic_loss_scale)

    # define loss function (criterion) and optimizer
    criterion_train = nn.CrossEntropyLoss().cuda() if cfg.dataaug.labelsmoothing_rate == 0.0 \
                        else LabelSmoothing(cfg.dataaug.labelsmoothing_rate).cuda()
    criterion_val = nn.CrossEntropyLoss().cuda()

    # Optionally resume from a checkpoint
    if args.resume:
        # Use a local scope to avoid dangling references
        def resume():
            # Without this declaration the assignment below would create a
            # function-local variable and the restored value would be lost.
            global best_prec1
            if os.path.isfile(args.resume):
                if args.local_rank == 0:
                    print("=> loading checkpoint '{}'".format(args.resume))
                checkpoint = torch.load(
                    args.resume,
                    map_location=lambda storage, loc: storage.cuda(args.gpu))
                args.start_epoch = checkpoint['epoch']
                best_prec1 = checkpoint['best_prec1']
                model.load_state_dict(checkpoint['state_dict'])
                optimizer.load_state_dict(checkpoint['optimizer'])
                if args.local_rank == 0:
                    print("=> loaded checkpoint '{}' (epoch {})".format(
                        args.resume, checkpoint['epoch']))
            else:
                if args.local_rank == 0:
                    print("=> no checkpoint found at '{}'".format(args.resume))

        resume()

    # Data loading code
    # None of the pipelines below converts to tensors or normalizes; the
    # loaders rely on the fast_collate collate function instead.
    lr_milestones = None
    if cfg.dataset == "cifar10":
        train_transform = transforms.Compose([
            transforms.RandomCrop(32, padding=4),
            transforms.RandomHorizontalFlip()
        ])
        train_dataset = datasets.CIFAR10('./datasets',
                                         train=True,
                                         download=False,
                                         transform=train_transform)
        val_dataset = datasets.CIFAR10('./datasets',
                                       train=False,
                                       download=False)
        lr_milestones = cfg.lr_milestones
    elif cfg.dataset == "cifar100":
        train_transform = transforms.Compose([
            transforms.RandomCrop(32, padding=4),
            transforms.RandomHorizontalFlip()
        ])
        train_dataset = datasets.CIFAR100('./datasets',
                                          train=True,
                                          download=False,
                                          transform=train_transform)
        val_dataset = datasets.CIFAR100('./datasets',
                                        train=False,
                                        download=False)
        lr_milestones = cfg.lr_milestones
    elif cfg.dataset == "imagenet":
        traindir = os.path.join(args.data, 'train')
        valdir = os.path.join(args.data, 'val')

        crop_size = cfg.crop_size  # 224
        val_size = cfg.crop_size + 32  # 256

        train_dataset = datasets.ImageFolder(
            traindir,
            transforms.Compose([
                transforms.RandomResizedCrop(
                    crop_size, interpolation=cfg.crop_interpolation),
                transforms.RandomHorizontalFlip(),
                # transforms.ToTensor(), Too slow
                # normalize,
            ]))
        val_dataset = datasets.ImageFolder(
            valdir,
            transforms.Compose([
                transforms.Resize(val_size,
                                  interpolation=cfg.crop_interpolation),
                transforms.CenterCrop(crop_size),
            ]))
    else:
        # Fail with a clear message instead of an UnboundLocalError on
        # train_dataset further down.
        raise NotImplementedError("Unknown dataset")

    train_sampler = None
    val_sampler = None
    if args.distributed:
        train_sampler = torch.utils.data.distributed.DistributedSampler(
            train_dataset)
        val_sampler = torch.utils.data.distributed.DistributedSampler(
            val_dataset)

    train_loader = torch.utils.data.DataLoader(train_dataset,
                                               batch_size=args.batch_size,
                                               shuffle=(train_sampler is None),
                                               num_workers=args.workers,
                                               pin_memory=True,
                                               sampler=train_sampler,
                                               collate_fn=fast_collate)

    val_loader = torch.utils.data.DataLoader(val_dataset,
                                             batch_size=args.batch_size,
                                             shuffle=False,
                                             num_workers=args.workers,
                                             pin_memory=True,
                                             sampler=val_sampler,
                                             collate_fn=fast_collate)

    if args.evaluate:
        validate(val_loader, model, criterion_val)
        return

    # The scheduler is given the wrapped torch optimizer when the fp16
    # wrapper is in use.
    scheduler = CosineAnnealingLR(
        optimizer.optimizer if args.fp16 else optimizer,
        args.epochs,
        len(train_loader),
        eta_min=cfg.cosine_lr_min,
        warmup=cfg.warmup_epochs) if cfg.use_cosine_lr else None

    for epoch in range(args.start_epoch, args.epochs):
        if args.distributed:
            train_sampler.set_epoch(epoch)

        # train for one epoch
        train(train_loader, model, criterion_train, optimizer, epoch,
              scheduler, lr_milestones, cfg.warmup_epochs,
              cfg.dataaug.mixup_rate, cfg.dataaug.labelsmoothing_rate)
        if args.prof:
            break
        # evaluate on validation set
        prec1 = validate(val_loader, model, criterion_val)

        # remember best prec@1 and save checkpoint
        if args.local_rank == 0:
            is_best = prec1 > best_prec1
            best_prec1 = max(prec1, best_prec1)
            save_checkpoint(
                {
                    'epoch': epoch + 1,
                    'arch': args.arch,
                    'state_dict': model.state_dict(),
                    'best_prec1': best_prec1,
                    'optimizer': optimizer.state_dict(),
                }, is_best, args.save_dir)
def main():
    """Parse the command line, build the model and run training/validation.

    Sets the module-level ``args``, ``best_prec1``, ``logger`` and
    ``norm_layer``.  Log files and per-epoch checkpoints are written below
    ``results/<log_dir>_<timestamp>/``; checkpoints are written by the process
    with ``args.local_rank == 0`` only.

    Raises:
        RuntimeError: if ``args.arch`` is ``"inception_v3"``.
    """
    global best_prec1, args

    args = parse()

    cudnn.benchmark = True
    best_prec1 = 0
    if args.deterministic:
        cudnn.benchmark = False
        cudnn.deterministic = True
        torch.manual_seed(args.local_rank)
        torch.set_printoptions(precision=10)

    args.distributed = False
    if 'WORLD_SIZE' in os.environ:
        args.distributed = int(os.environ['WORLD_SIZE']) > 1

    # Make the run directory unique by appending the current local time.
    args.log_dir = args.log_dir + '_' + time.asctime(time.localtime(time.time())).replace(" ", "-")
    os.makedirs('results/{}'.format(args.log_dir), exist_ok=True)
    global logger
    logger = create_logger('global_logger', "results/{}/log.txt".format(args.log_dir))
    args.gpu = 0
    args.world_size = 1

    if args.distributed:
        logger.info(args.local_rank)
        args.gpu = args.local_rank
        torch.cuda.set_device(args.gpu)
        torch.distributed.init_process_group(backend='nccl',
                                             init_method='env://')
        args.world_size = torch.distributed.get_world_size()
    logger.info(args.world_size)
    if args.local_rank == 0:
        wandb.init(project="tinyimagenet",
                   dir="results/{}".format(args.log_dir),
                   name=args.log_dir)
        wandb.config.update(args)

        logger.info("\nCUDNN VERSION: {}\n".format(torch.backends.cudnn.version()))

    # --batch_size is the global batch size; each process gets an equal share.
    args.batch_size = int(args.batch_size/args.world_size)
    logger.info(args.batch_size)

    assert torch.backends.cudnn.enabled, "Amp requires cudnn backend to be enabled."

    if args.channels_last:
        memory_format = torch.channels_last
    else:
        memory_format = torch.contiguous_format

    # create model

    import models
    logger.info('==> Building model..')
    global norm_layer
    print(args.norm_layer)
    # Default to None so the name is always bound, even when --norm_layer is
    # unset or the string 'False' (previously it could be left undefined).
    norm_layer = None
    if args.norm_layer is not None and args.norm_layer != 'False':
        if args.norm_layer == 'cbn':
            norm_layer = models.__dict__['Constraint_Norm2d']
        elif args.norm_layer == 'cbn_mu_v1':
            norm_layer = models.__dict__['Constraint_Norm_mu_v1_2d']

    if args.pretrained:
        logger.info("=> using pre-trained model '{}'".format(args.arch))
        model = models.__dict__[args.arch](pretrained=True, norm_layer=norm_layer)
    else:
        logger.info("=> creating model '{}'".format(args.arch))
        model = models.__dict__[args.arch](norm_layer=norm_layer)

    model = model.cuda()

    # Scale learning rate based on global batch size
    args.lr = args.lr*float(args.batch_size*args.world_size)/256.

    # Collect the ids of every parameter owned by a Constraint_Lagrangian
    # module so they can be placed in their own optimizer group.
    constraint_param = set()
    for m in model.modules():
        if isinstance(m, Constraint_Lagrangian):
            m.weight_decay = args.constraint_decay
            m.get_optimal_lagrangian = args.get_optimal_lagrangian
            constraint_param.update(map(id, m.parameters()))

    if args.decrease_affine_lr == 1:
        origin_param = [p for p in model.parameters()
                        if id(p) not in constraint_param]
        lagrangian_param = [p for p in model.parameters()
                            if id(p) in constraint_param]

        optimizer = optim.SGD([
                        {'params': origin_param},
                        {'params': lagrangian_param,
                         'lr': args.constraint_lr,
                         'weight_decay': args.constraint_decay},
                        ],
                        lr=args.lr, momentum=0.9,
                        weight_decay=args.decay)
    else:
        # Fallback: one parameter group with the shared hyper-parameters.
        # Previously `optimizer` was left undefined on this path and the
        # script failed later with a NameError.
        # NOTE(review): confirm this is the intended behaviour for other
        # values of --decrease_affine_lr.
        optimizer = optim.SGD(model.parameters(),
                              lr=args.lr, momentum=0.9,
                              weight_decay=args.decay)

    # Initialize Amp.  Amp accepts either values or strings for the optional override arguments,
    # for convenient interoperation with argparse.
    if args.mixed_precision:
        model, optimizer = amp.initialize(model, optimizer,
                                        opt_level=args.opt_level,
                                        keep_batchnorm_fp32=args.keep_batchnorm_fp32,
                                        loss_scale=args.loss_scale
                                        )

    # For distributed training, wrap the model with apex.parallel.DistributedDataParallel.
    # This must be done AFTER the call to amp.initialize.  If model = DDP(model) is called
    # before model, ... = amp.initialize(model, ...), the call to amp.initialize may alter
    # the types of model's parameters in a way that disrupts or destroys DDP's allreduce hooks.
    if args.distributed:
        # delay_allreduce delays all communication to the end of the backward
        # pass (the default would overlap it with computation).
        model = DDP(model, delay_allreduce=True)

    # define loss function (criterion) and optimizer
    criterion = nn.CrossEntropyLoss().cuda()

    # Optionally resume from a checkpoint
    print(args.resume)
    if args.resume:
        # Use a local scope to avoid dangling references
        def resume():
            # Without this declaration the assignment below would create a
            # function-local variable and the restored value would be lost.
            global best_prec1
            if os.path.isfile(args.resume):
                logger.info("=> loading checkpoint '{}'".format(args.resume))
                checkpoint = torch.load(args.resume, map_location = lambda storage, loc: storage.cuda(args.gpu))
                best_prec1 = checkpoint['best_prec1']
                model.load_state_dict(checkpoint['state_dict'])
                optimizer.load_state_dict(checkpoint['optimizer'])
                logger.info("=> loaded checkpoint '{}' (epoch {})"
                      .format(args.resume, checkpoint['epoch']))
            else:
                logger.info("=> no checkpoint found at '{}'".format(args.resume))
        resume()

    # Data loading code
    traindir = os.path.join(args.data, 'train')
    valdir = os.path.join(args.data, 'val')

    if(args.arch == "inception_v3"):
        raise RuntimeError("Currently, inception_v3 is not supported by this example.")
        # crop_size = 299
        # val_size = 320 # I chose this value arbitrarily, we can adjust.
    else:
        crop_size = 224
        val_size = 256

    train_dataset = datasets.ImageFolder(
        traindir,
        transforms.Compose([
            transforms.RandomResizedCrop(crop_size),
            transforms.RandomHorizontalFlip(),
            # transforms.ToTensor(), Too slow
            # normalize,
        ]))
    val_dataset = datasets.ImageFolder(valdir, transforms.Compose([
            transforms.Resize(val_size),
            transforms.CenterCrop(crop_size),
        ]))

    train_sampler = None
    val_sampler = None
    if args.distributed:
        train_sampler = torch.utils.data.distributed.DistributedSampler(train_dataset)
        val_sampler = torch.utils.data.distributed.DistributedSampler(val_dataset)

    collate_fn = lambda b: fast_collate(b, memory_format)

    train_loader = torch.utils.data.DataLoader(
        train_dataset, batch_size=args.batch_size, shuffle=(train_sampler is None),
        num_workers=args.workers, pin_memory=True, sampler=train_sampler, collate_fn=collate_fn)

    val_loader = torch.utils.data.DataLoader(
        val_dataset,
        batch_size=args.batch_size, shuffle=False,
        num_workers=args.workers, pin_memory=True,
        sampler=val_sampler,
        collate_fn=collate_fn)

    if args.evaluate:
        # NOTE(review): the per-epoch call further down passes `epoch` as the
        # first argument but this call does not - confirm validate()'s
        # signature.
        validate(val_loader, model, criterion)
        return

    #initialization
    device = torch.device("cuda")

    # Attach the noise-sampling settings to every custom normalization layer.
    # Skipped when no custom layer class was selected: isinstance() with None
    # as its second argument raises TypeError.
    if norm_layer is not None:
        for m in model.modules():
            if isinstance(m, norm_layer):
                m.sample_noise = args.sample_noise
                m.sample_mean = torch.zeros(m.num_features).to(device)
                m.add_noise = args.add_noise
                m.sample_mean_std = torch.sqrt(torch.Tensor([args.noise_mean_std])[0].to(device))
                m.sample_var_std = torch.sqrt(torch.Tensor([args.noise_var_std])[0].to(device))

    for epoch in range(args.start_epoch, args.epochs):
        if args.distributed:
            train_sampler.set_epoch(epoch)

        # At each epoch listed in --warmup_noise, scale both noise standard
        # deviations by sqrt(warmup_scale).
        if (args.warmup_noise is not None and epoch in args.warmup_noise
                and norm_layer is not None):
            for m in model.modules():
                if isinstance(m, norm_layer):
                    m.sample_mean_std *= math.sqrt(args.warmup_scale)
                    m.sample_var_std *= math.sqrt(args.warmup_scale)

        # train for one epoch
        train(train_loader, model, criterion, optimizer, epoch)

        # evaluate on validation set
        prec1 = validate(epoch, val_loader, model, criterion)

        # remember best prec@1 and save checkpoint
        if args.local_rank == 0:
            is_best = prec1 > best_prec1
            best_prec1 = max(prec1, best_prec1)
            save_checkpoint({
                'epoch': epoch + 1,
                'arch': args.arch,
                'state_dict': model.state_dict(),
                'best_prec1': best_prec1,
                'optimizer' : optimizer.state_dict(),
            }, is_best, filename = os.path.join("results/" + args.log_dir, "{}_checkpoint.pth.tar".format(epoch)))
Пример #3
0
class RunManager:
    def __init__(self, path, net, run_config: RunConfig, out_log=True):
        """Prepare a network for distributed training.

        Selects the CUDA device given by ``run_config.local_rank``,
        initializes the network's weights and moves it to the GPU, rescales
        ``run_config.init_lr`` in place by the global batch size divided by
        256, builds the optimizer and finally wraps the network in ``DDP``.

        Args:
            path: root directory for this run's checkpoints and logs.
            net: the network to train.
            run_config: run settings (learning rate, batch size, rank, ...).
            out_log: whether to print progress information.
        """
        self.path = path
        self.net = net
        self.run_config = run_config
        self.out_log = out_log

        # Both directories are created lazily by the matching properties.
        self._logs_path = None
        self._save_path = None
        self.best_acc = 0
        self.start_epoch = 0
        torch.cuda.set_device(self.run_config.local_rank)

        # initialize model (default)
        self.net.init_model(run_config.model_init, run_config.init_div_groups)

        # Move to the GPU; only rank 0 reports the network statistics.
        self.net = self.net.cuda()
        if run_config.local_rank == 0:
            self.print_net_info()

        if self.run_config.sync_bn:
            self.net = apex.parallel.convert_syncbn_model(self.net)
        print('local_rank: %d' % self.run_config.local_rank)

        # Scale the learning rate with the global batch size.
        base_lr = self.run_config.init_lr
        global_batch = float(self.run_config.train_batch_size *
                             self.run_config.world_size)
        self.run_config.init_lr = base_lr * global_batch / 256.
        self.criterion = nn.CrossEntropyLoss()

        if not self.run_config.no_decay_keys:
            self.optimizer = self.run_config.build_optimizer(
                self.net.weight_parameters())
        else:
            no_decay = self.run_config.no_decay_keys.split('#')
            with_decay_params = self.net.get_parameters(no_decay,
                                                        mode='exclude')
            without_decay_params = self.net.get_parameters(no_decay,
                                                           mode='include')
            self.optimizer = self.run_config.build_optimizer(
                [with_decay_params, without_decay_params])

        self.net = DDP(self.net, delay_allreduce=True)
        cudnn.benchmark = True

    """ save path and log path """

    @property
    def save_path(self):
        if self._save_path is None:
            save_path = os.path.join(self.path, 'checkpoint')
            os.makedirs(save_path, exist_ok=True)
            self._save_path = save_path
        return self._save_path

    @property
    def logs_path(self):
        if self._logs_path is None:
            logs_path = os.path.join(self.path, 'logs')
            os.makedirs(logs_path, exist_ok=True)
            self._logs_path = logs_path
        return self._logs_path

    """ net info """

    def reset_model(self, model, model_origin=None):
        """Replace the managed network and rebuild optimizer and DDP wrapper.

        The new network is re-initialized, optionally filled from
        ``model_origin`` through ``get_unpruned_weights``, moved to the GPU
        and wrapped in ``DDP``; a fresh criterion and optimizer are created.
        Unlike ``__init__`` this does not rescale ``run_config.init_lr``.

        Args:
            model: the network that replaces ``self.net``.
            model_origin: optional source network passed to
                ``get_unpruned_weights`` together with the new network.
        """
        self.net = model
        self.net.init_model(self.run_config.model_init,
                            self.run_config.init_div_groups)
        # Identity test: `!= None` would go through any custom __ne__.
        if model_origin is not None:
            if self.run_config.local_rank == 0:
                print('-' * 30 + ' start pruning ' + '-' * 30)
            get_unpruned_weights(self.net, model_origin)
            if self.run_config.local_rank == 0:
                print('-' * 30 + ' end pruning ' + '-' * 30)
        # net info
        self.net = self.net.cuda()
        if self.run_config.local_rank == 0:
            self.print_net_info()

        if self.run_config.sync_bn:
            self.net = apex.parallel.convert_syncbn_model(self.net)
        print('local_rank: %d' % self.run_config.local_rank)

        self.criterion = nn.CrossEntropyLoss()
        if self.run_config.no_decay_keys:
            keys = self.run_config.no_decay_keys.split('#')
            self.optimizer = self.run_config.build_optimizer([
                self.net.get_parameters(
                    keys, mode='exclude'),  # parameters with weight decay
                self.net.get_parameters(
                    keys, mode='include'),  # parameters without weight decay
            ])
        else:
            self.optimizer = self.run_config.build_optimizer(
                self.net.weight_parameters())
        self.net = DDP(self.net, delay_allreduce=True)
        cudnn.benchmark = True

    # noinspection PyUnresolvedReferences
    def net_flops(self):
        """Return ``profile_macs`` of the network for one zero-filled sample.

        The input is a batch of size one whose remaining dimensions come from
        ``run_config.data_provider.data_shape``.
        """
        sample_shape = [1, *self.run_config.data_provider.data_shape]
        dummy_input = torch.zeros(sample_shape).cuda()
        # Profiling only needs a forward pass, so gradients are disabled.
        with torch.no_grad():
            return profile_macs(self.net, dummy_input)

    def print_net_info(self):
        """Report parameter count, FLOPs and config; save them as JSON.

        The three values are printed when ``self.out_log`` is set and are
        always written to ``net_info.txt`` in the log directory.
        """
        verbose = self.out_log

        # parameters
        param_text = '%.2fM' % (count_parameters(self.net) / 1e6)
        if verbose:
            print('Total training params: ' + param_text)
        summary = {'param': param_text}

        # flops
        flops_text = '%.1fM' % (self.net_flops() / 1e6)
        if verbose:
            print('Total FLOPs: ' + flops_text)
        summary['flops'] = flops_text

        # config
        config_text = str(self.net.config)
        if verbose:
            print('Net config: ' + config_text)
        summary['config'] = config_text

        with open('%s/net_info.txt' % self.logs_path, 'w') as fout:
            fout.write(json.dumps(summary, indent=4) + '\n')

    """ save and load models """

    def save_model(self, checkpoint=None, is_best=False, model_name=None):
        if checkpoint is None:
            checkpoint = {'state_dict': self.net.module.state_dict()}

        if model_name is None:
            model_name = 'checkpoint.pth.tar'

        checkpoint[
            'dataset'] = self.run_config.dataset  # add `dataset` info to the checkpoint
        latest_fname = os.path.join(self.save_path, 'latest.txt')
        model_path = os.path.join(self.save_path, model_name)
        with open(latest_fname, 'w') as fout:
            fout.write(model_path + '\n')
        torch.save(checkpoint, model_path)

        if is_best:
            best_path = os.path.join(self.save_path, 'model_best.pth.tar')
            torch.save({'state_dict': checkpoint['state_dict']}, best_path)

    def load_model(self, model_fname=None):
        latest_fname = os.path.join(self.save_path, 'latest.txt')
        if model_fname is None and os.path.exists(latest_fname):
            with open(latest_fname, 'r') as fin:
                model_fname = fin.readline()
                if model_fname[-1] == '\n':
                    model_fname = model_fname[:-1]
        # noinspection PyBroadException
        try:
            if model_fname is None or not os.path.exists(model_fname):
                model_fname = '%s/checkpoint.pth.tar' % self.save_path
                with open(latest_fname, 'w') as fout:
                    fout.write(model_fname + '\n')
            if self.out_log:
                print("=> loading checkpoint '{}'".format(model_fname))

            if torch.cuda.is_available():
                checkpoint = torch.load(model_fname)
            else:
                checkpoint = torch.load(model_fname, map_location='cpu')

            self.net.module.load_state_dict(checkpoint['state_dict'])
            # set new manual seed
            new_manual_seed = int(time.time())
            torch.manual_seed(new_manual_seed)
            torch.cuda.manual_seed_all(new_manual_seed)
            np.random.seed(new_manual_seed)

            if 'epoch' in checkpoint:
                self.start_epoch = checkpoint['epoch'] + 1
            if 'best_acc' in checkpoint:
                self.best_acc = checkpoint['best_acc']
            if 'optimizer' in checkpoint:
                self.optimizer.load_state_dict(checkpoint['optimizer'])

            if self.out_log:
                print("=> loaded checkpoint '{}'".format(model_fname))
        except Exception:
            if self.out_log:
                print('fail to load checkpoint from %s' % self.save_path)

    def save_config(self, print_info=True):
        """ dump run_config and net_config to the model_folder """
        os.makedirs(self.path, exist_ok=True)
        net_save_path = os.path.join(self.path, 'net.config')
        json.dump(self.net.module.config, open(net_save_path, 'w'), indent=4)
        if print_info:
            print('Network configs dump to %s' % net_save_path)

        run_save_path = os.path.join(self.path, 'run.config')
        json.dump(self.run_config.config, open(run_save_path, 'w'), indent=4)
        if print_info:
            print('Run configs dump to %s' % run_save_path)

    """ train and test """

    def write_log(self, log_str, prefix, should_print=True):
        """ prefix: valid, train, test """
        if prefix in ['valid', 'test']:
            with open(os.path.join(self.logs_path, 'valid_console.txt'),
                      'a') as fout:
                fout.write(log_str + '\n')
                fout.flush()
        if prefix in ['valid', 'test', 'train']:
            with open(os.path.join(self.logs_path, 'train_console.txt'),
                      'a') as fout:
                if prefix in ['valid', 'test']:
                    fout.write('=' * 10)
                fout.write(log_str + '\n')
                fout.flush()
        if prefix in ['prune']:
            with open(os.path.join(self.logs_path, 'prune_console.txt'),
                      'a') as fout:
                if prefix in ['valid', 'test']:
                    fout.write('=' * 10)
                fout.write(log_str + '\n')
                fout.flush()
        if should_print:
            print(log_str)

    def validate(self,
                 is_test=True,
                 net=None,
                 use_train_mode=False,
                 return_top5=False):
        """Evaluate a network on the test or validation loader.

        Args:
            is_test: use ``run_config.test_loader`` when true, otherwise
                ``run_config.valid_loader``.
            net: network to evaluate; ``self.net`` when ``None``.
            use_train_mode: call ``net.train()`` instead of ``net.eval()``
                before evaluating.
            return_top5: also log and return the top-5 meter average.

        Returns:
            ``(losses.avg, top1.avg)`` or, with ``return_top5``,
            ``(losses.avg, top1.avg, top5.avg)``.
        """
        data_loader = (self.run_config.test_loader
                       if is_test else self.run_config.valid_loader)
        if net is None:
            net = self.net
        (net.train if use_train_mode else net.eval)()

        batch_time, losses, top1, top5 = (AverageMeter() for _ in range(4))

        log_prefix = 'Test' if is_test else 'Valid'
        log_template = (': [{0}/{1}]\t'
                        'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t'
                        'Loss {loss.val:.4f} ({loss.avg:.4f})\t'
                        'Top-1 acc {top1.val:.3f} ({top1.avg:.3f})')
        top5_template = '\tTop-5 acc {top5.val:.3f} ({top5.avg:.3f})'

        tick = time.time()
        with torch.no_grad():
            for step, batch in enumerate(data_loader):
                images = batch[0].cuda(non_blocking=True)
                labels = batch[1].cuda(non_blocking=True)

                # forward pass and loss
                output = net(images)
                loss = self.criterion(output, labels)

                # accuracy, reduced through self.reduce_tensor before
                # being recorded
                acc1, acc5 = accuracy(output, labels, topk=(1, 5))
                reduced_loss = self.reduce_tensor(loss.data)
                acc1 = self.reduce_tensor(acc1)
                acc5 = self.reduce_tensor(acc5)
                n_samples = images.size(0)
                losses.update(reduced_loss, n_samples)
                top1.update(acc1[0], n_samples)
                top5.update(acc5[0], n_samples)

                # wall-clock time of this batch
                batch_time.update(time.time() - tick)
                tick = time.time()

                # only log every print_frequency batches and on the last one
                if (step % self.run_config.print_frequency != 0
                        and step + 1 != len(data_loader)):
                    continue
                test_log = log_prefix + log_template.format(
                    step,
                    len(data_loader) - 1,
                    batch_time=batch_time,
                    loss=losses,
                    top1=top1)
                if return_top5:
                    test_log += top5_template.format(top5=top5)
                print(test_log)

        # both evaluation loaders are reset regardless of which one was used
        self.run_config.valid_loader.reset()
        self.run_config.test_loader.reset()

        if return_top5:
            return losses.avg, top1.avg, top5.avg
        return losses.avg, top1.avg

    def train_bn(self, epochs=1):
        """Re-estimate BatchNorm running statistics on the training data.

        Every ``torch.nn.BatchNorm2d`` in ``self.net`` gets its
        ``running_mean`` replaced by zeros and its ``running_var`` by ones;
        the network is then put in train mode and fed each batch of
        ``run_config.train_loader`` ``epochs`` times. No loss is computed and
        no optimizer step is taken, so only the BatchNorm buffers change.

        Args:
            epochs: number of passes over the training loader.
        """
        is_main_rank = self.run_config.local_rank == 0
        if is_main_rank:
            print('training bn')

        for m in self.net.modules():
            if isinstance(m, torch.nn.BatchNorm2d):
                m.running_mean = torch.zeros_like(m.running_mean)
                m.running_var = torch.ones_like(m.running_var)

        self.net.train()
        # Only the forward-pass buffer updates are wanted; disabling autograd
        # avoids building graphs that are never back-propagated. BatchNorm
        # still updates its running statistics in train mode under no_grad.
        with torch.no_grad():
            for _ in range(epochs):
                for data in self.run_config.train_loader:
                    self.net(data[0].cuda(non_blocking=True))

        if is_main_rank:
            print('training bn finished')

    def train_one_epoch(self, adjust_lr_func, train_log_func, epoch):
        """Run one pass over the training loader, stepping the optimizer per batch.

        Args:
            adjust_lr_func: callable taking the batch index; its return value
                is passed on to ``train_log_func`` as the learning rate.
            train_log_func: callable
                ``(i, batch_time, data_time, losses, top1, top5, lr)``
                returning the log line for a batch.
            epoch: not used inside this method.

        Returns:
            The ``(top1, top5)`` meters accumulated over the epoch.
        """
        batch_time, data_time, losses, top1, top5 = (
            AverageMeter() for _ in range(5))

        self.net.train()

        tick = time.time()
        for step, batch in enumerate(self.run_config.train_loader):
            data_time.update(time.time() - tick)
            new_lr = adjust_lr_func(step)

            images = batch[0].cuda(non_blocking=True)
            labels = batch[1].cuda(non_blocking=True)

            # forward pass; label smoothing replaces the plain criterion
            output = self.net(images)
            if self.run_config.label_smoothing > 0:
                loss = cross_entropy_with_label_smoothing(
                    output, labels, self.run_config.label_smoothing)
            else:
                loss = self.criterion(output, labels)

            # metrics are passed through self.reduce_tensor before recording
            acc1, acc5 = accuracy(output, labels, topk=(1, 5))
            reduced_loss = self.reduce_tensor(loss.data)
            acc1 = self.reduce_tensor(acc1)
            acc5 = self.reduce_tensor(acc5)
            n_samples = images.size(0)
            losses.update(reduced_loss, n_samples)
            top1.update(acc1[0], n_samples)
            top5.update(acc5[0], n_samples)

            # backward pass and parameter update
            self.net.zero_grad()
            loss.backward()
            self.optimizer.step()
            torch.cuda.synchronize()

            batch_time.update(time.time() - tick)
            tick = time.time()

            # log every print_frequency batches and on the final batch,
            # on local rank 0 only
            at_log_step = (step % self.run_config.print_frequency == 0
                           or step + 1 == len(self.run_config.train_loader))
            if at_log_step and self.run_config.local_rank == 0:
                self.write_log(
                    train_log_func(step, batch_time, data_time, losses, top1,
                                   top5, new_lr), 'train')
        return top1, top5

    def train(self, print_top5=False):
        """Train from ``self.start_epoch`` up to ``run_config.n_epochs``.

        Each epoch runs ``train_one_epoch``, prints a timing estimate,
        validates every ``run_config.validation_frequency`` epochs (tracking
        ``self.best_acc``), saves a checkpoint on local rank 0 and finally
        resets the train, valid and test loaders.

        Args:
            print_top5: include top-5 accuracy in the batch and validation
                log lines.
        """
        batch_template = ('Train [{0}][{1}/{2}]\t'
                          'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t'
                          'Data {data_time.val:.3f} ({data_time.avg:.3f})\t'
                          'Loss {losses.val:.4f} ({losses.avg:.4f})\t'
                          'Top-1 acc {top1.val:.3f} ({top1.avg:.3f})')

        def format_batch_log(epoch_, i, batch_time, data_time, losses, top1,
                             top5, lr):
            # one console line describing the current training batch
            pieces = [
                batch_template.format(epoch_ + 1,
                                      i,
                                      len(self.run_config.train_loader) - 1,
                                      batch_time=batch_time,
                                      data_time=data_time,
                                      losses=losses,
                                      top1=top1)
            ]
            if print_top5:
                pieces.append(
                    '\tTop-5 acc {top5.val:.3f} ({top5.avg:.3f})'.format(
                        top5=top5))
            pieces.append('\tlr {lr:.5f}'.format(lr=lr))
            return ''.join(pieces)

        for epoch in range(self.start_epoch, self.run_config.n_epochs):
            if self.run_config.local_rank == 0:
                print('\n', '-' * 30, 'Train epoch: %d' % (epoch + 1),
                      '-' * 30, '\n')

            def adjust_lr(i):
                # learning rate for batch i of the current epoch
                return self.run_config.adjust_learning_rate(
                    self.optimizer, epoch, i,
                    len(self.run_config.train_loader))

            def log_batch(i, batch_time, data_time, losses, top1, top5,
                          new_lr):
                return format_batch_log(epoch, i, batch_time, data_time,
                                        losses, top1, top5, new_lr)

            epoch_start = time.time()
            train_top1, train_top5 = self.train_one_epoch(
                adjust_lr, log_batch, epoch)
            time_per_epoch = time.time() - epoch_start
            seconds_left = int(
                (self.run_config.n_epochs - epoch - 1) * time_per_epoch)
            if self.run_config.local_rank == 0:
                print('Time per epoch: %s, Est. complete in: %s' %
                      (str(timedelta(seconds=time_per_epoch)),
                       str(timedelta(seconds=seconds_left))))

            is_best = False
            if (epoch + 1) % self.run_config.validation_frequency == 0:
                val_loss, val_acc, val_acc5 = self.validate(is_test=False,
                                                            return_top5=True)
                is_best = val_acc > self.best_acc
                self.best_acc = max(self.best_acc, val_acc)

                val_log = 'Valid [{0}/{1}]\tloss {2:.3f}\ttop-1 acc {3:.3f} ({4:.3f})'.format(
                    epoch + 1, self.run_config.n_epochs, val_loss, val_acc,
                    self.best_acc)
                if print_top5:
                    val_log += '\ttop-5 acc {0:.3f}\tTrain top-1 {top1.avg:.3f}\ttop-5 {top5.avg:.3f}'.format(
                        val_acc5, top1=train_top1, top5=train_top5)
                else:
                    val_log += '\tTrain top-1 {top1.avg:.3f}'.format(
                        top1=train_top1)
                if self.run_config.local_rank == 0:
                    self.write_log(val_log, 'valid')

            if self.run_config.local_rank == 0:
                checkpoint = {
                    'epoch': epoch,
                    'best_acc': self.best_acc,
                    'optimizer': self.optimizer.state_dict(),
                    'state_dict': self.net.state_dict(),
                }
                self.save_model(checkpoint, is_best=is_best)

            # reset all three loaders before the next epoch
            self.run_config.train_loader.reset()
            self.run_config.valid_loader.reset()
            self.run_config.test_loader.reset()

    def reduce_tensor(self, tensor):
        """Return the average of ``tensor`` across all processes.

        A clone of ``tensor`` is sum-reduced over the default process group
        and divided by ``run_config.world_size``; the input is not modified.
        When ``torch.distributed`` is unavailable or no process group has been
        initialised (single-process runs) there is nothing to reduce, so the
        clone is returned unchanged instead of raising.

        Args:
            tensor: the per-process tensor to average.

        Returns:
            A new tensor holding the cross-process mean.
        """
        rt = tensor.clone()
        if not (dist.is_available() and dist.is_initialized()):
            return rt
        dist.all_reduce(rt, op=dist.ReduceOp.SUM)
        rt /= self.run_config.world_size
        return rt
Пример #4
0
def main():
    """Train (unless ``args.test_mode``) and then evaluate a slimmable model.

    Stages, in order:

    1. Parse arguments; on local rank 0 create the output directory, the
       checkpoint saver and dump the arguments to ``args.yaml``.
    2. Configure single-process or distributed (NCCL, ``env://``) execution
       and seed the RNGs.
    3. Build the model, the optimizer for ``args.train_mode``, and the
       optional AMP, EMA, SyncBN and DDP wrappers, then the LR scheduler.
    4. Create the train, BN-recalibration and eval loaders and the losses.
    5. Run the epoch loop with per-epoch validation, summary and checkpoint.
    6. Evaluate each of the ``args.num_choice`` sub-networks, re-estimating
       BatchNorm statistics for all but the first and last choice.

    Raises:
        ValueError: if ``args.train_mode`` is not ``'se'``, ``'bn'``,
            ``'all'`` or ``'gate'``.
        SystemExit: with code 1 if the training or validation folder is
            missing.
    """
    import os

    args, args_text = _parse_args()

    eval_metric = args.eval_metric
    best_metric = None
    best_epoch = None
    saver = None
    output_dir = ''
    if args.local_rank == 0:
        output_base = args.output if args.output else './output'
        exp_name = 'train'
        if args.gate_train:
            exp_name += '-dynamic'
        if args.slim_train:
            exp_name += '-slimmable'
        exp_name += '-{}'.format(args.model)
        exp_info = '-'.join(
            [datetime.now().strftime("%Y%m%d-%H%M%S"), args.model])
        output_dir = get_outdir(output_base, exp_name, exp_info)
        # a loss is better when lower; every other metric when higher
        decreasing = eval_metric == 'loss'
        saver = CheckpointSaver(checkpoint_dir=output_dir,
                                decreasing=decreasing)
        with open(os.path.join(output_dir, 'args.yaml'), 'w') as f:
            f.write(args_text)
    setup_default_logging(outdir=output_dir, local_rank=args.local_rank)

    torch.backends.cudnn.benchmark = True

    args.prefetcher = not args.no_prefetcher
    args.distributed = False
    if 'WORLD_SIZE' in os.environ:
        args.distributed = int(os.environ['WORLD_SIZE']) > 1
        if args.distributed and args.num_gpu > 1:
            logging.warning(
                'Using more than one GPU per process in distributed mode is not allowed. Setting num_gpu to 1.'
            )
            args.num_gpu = 1

    args.device = 'cuda:0'
    args.world_size = 1
    args.rank = 0  # global rank
    if args.distributed:
        args.num_gpu = 1
        args.device = 'cuda:%d' % args.local_rank
        torch.cuda.set_device(args.local_rank)
        torch.distributed.init_process_group(backend='nccl',
                                             init_method='env://')
        args.world_size = torch.distributed.get_world_size()
        args.rank = torch.distributed.get_rank()
    assert args.rank >= 0

    if args.distributed:
        logging.info(
            'Training in distributed mode with multiple processes, 1 GPU per process. Process %d, total %d.'
            % (args.rank, args.world_size))
    else:
        logging.info('Training with a single process on %d GPUs.' %
                     args.num_gpu)

    # --------- random seed -----------
    random.seed(args.seed)  # TODO: do we need same seed on all GPU?
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)

    model = create_model(args.model,
                         pretrained=args.pretrained,
                         num_classes=args.num_classes,
                         drop_rate=args.drop,
                         drop_path_rate=args.drop_path,
                         global_pool=args.gp,
                         bn_tf=args.bn_tf,
                         bn_momentum=args.bn_momentum,
                         bn_eps=args.bn_eps,
                         checkpoint_path=args.initial_checkpoint)

    # optionally resume from a checkpoint
    resume_state = {}
    resume_epoch = None
    if args.resume:
        resume_state, resume_epoch = resume_checkpoint(model, args.resume)

    if args.local_rank == 0:
        logging.info('Model %s created, param count: %d' %
                     (args.model, sum(m.numel()
                                      for m in model.parameters())))

    data_config = resolve_data_config(vars(args),
                                      model=model,
                                      verbose=args.local_rank == 0)

    num_aug_splits = 0
    if args.aug_splits > 0:
        assert args.aug_splits > 1, 'A split of 1 makes no sense'
        num_aug_splits = args.aug_splits

    if args.split_bn:
        assert num_aug_splits > 1 or args.resplit
        model = convert_splitbn_model(model, max(num_aug_splits, 2))

    if args.num_gpu > 1:
        if args.amp:
            logging.warning(
                'AMP does not work well with nn.DataParallel, disabling. Use distributed mode for multi-GPU AMP.'
            )
            args.amp = False
        model = nn.DataParallel(model,
                                device_ids=list(range(args.num_gpu))).cuda()
    else:
        model.cuda()

    # The optimizer only receives the parameter group selected by train_mode.
    if args.train_mode == 'se':
        optimizer = create_optimizer(args, model.get_se())
    elif args.train_mode == 'bn':
        optimizer = create_optimizer(args, model.get_bn())
    elif args.train_mode == 'all':
        optimizer = create_optimizer(args, model)
    elif args.train_mode == 'gate':
        optimizer = create_optimizer(args, model.get_gate())
    else:
        # Fail here with a clear message rather than with an unbound
        # ``optimizer`` further down.
        raise ValueError(
            "Unknown train_mode {!r}; expected 'se', 'bn', 'all' or 'gate'.".
            format(args.train_mode))

    use_amp = False
    if has_apex and args.amp:
        model, optimizer = amp.initialize(model, optimizer, opt_level='O1')
        use_amp = True
    if args.local_rank == 0:
        logging.info('NVIDIA APEX {}. AMP {}.'.format(
            'installed' if has_apex else 'not installed',
            'on' if use_amp else 'off'))

    if resume_state and not args.no_resume_opt:
        # ----------- Load Optimizer ---------
        if 'optimizer' in resume_state:
            if args.local_rank == 0:
                logging.info('Restoring Optimizer state from checkpoint')
            optimizer.load_state_dict(resume_state['optimizer'])
        if use_amp and 'amp' in resume_state and 'load_state_dict' in amp.__dict__:
            if args.local_rank == 0:
                logging.info('Restoring NVIDIA AMP state from checkpoint')
            amp.load_state_dict(resume_state['amp'])
    del resume_state

    model_ema = None
    if args.model_ema:
        # Important to create EMA model after cuda(), DP wrapper, and AMP but before SyncBN and DDP wrapper
        model_ema = ModelEma(model,
                             decay=args.model_ema_decay,
                             device='cpu' if args.model_ema_force_cpu else '',
                             resume=args.resume)

    if args.distributed:
        if args.sync_bn:
            assert not args.split_bn
            try:
                if has_apex:
                    model = convert_syncbn_model(model)
                else:
                    model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(
                        model)
                if args.local_rank == 0:
                    logging.info(
                        'Converted model to use Synchronized BatchNorm. WARNING: You may have issues if using '
                        'zero initialized BN layers (enabled by default for ResNets) while sync-bn enabled.'
                    )
            except Exception:
                # Best effort: training continues without SyncBN, but the
                # underlying error is kept in the log (with traceback).
                logging.exception(
                    'Failed to enable Synchronized BatchNorm. Install Apex or Torch >= 1.1'
                )
        if has_apex:
            model = DDP(model, delay_allreduce=True)
        else:
            if args.local_rank == 0:
                logging.info(
                    "Using torch DistributedDataParallel. Install NVIDIA Apex for Apex DDP."
                )
            model = DDP(model,
                        device_ids=[args.local_rank],
                        find_unused_parameters=True
                        )  # can use device str in Torch >= 1.1
        # NOTE: EMA model does not need to be wrapped by DDP

    lr_scheduler, num_epochs = create_scheduler(args, optimizer)
    start_epoch = 0
    if args.start_epoch is not None:
        # a specified start_epoch will always override the resume epoch
        start_epoch = args.start_epoch
    elif resume_epoch is not None:
        start_epoch = resume_epoch
    if lr_scheduler is not None and start_epoch > 0:
        lr_scheduler.step(start_epoch)

    if args.local_rank == 0:
        logging.info('Scheduled epochs: {}'.format(num_epochs))

    # ------------- data --------------
    train_dir = os.path.join(args.data, 'train')
    if not os.path.exists(train_dir):
        logging.error(
            'Training folder does not exist at: {}'.format(train_dir))
        raise SystemExit(1)
    dataset_train = Dataset(train_dir)
    collate_fn = None
    if num_aug_splits > 1:
        dataset_train = AugMixDataset(dataset_train, num_splits=num_aug_splits)

    # loader_train and loader_bn share every setting except the batch size.
    train_loader_kwargs = dict(
        input_size=data_config['input_size'],
        is_training=True,
        use_prefetcher=args.prefetcher,
        re_prob=args.reprob,
        re_mode=args.remode,
        re_count=args.recount,
        re_split=args.resplit,
        color_jitter=args.color_jitter,
        auto_augment=args.aa,
        num_aug_splits=num_aug_splits,
        interpolation=args.train_interpolation,
        mean=data_config['mean'],
        std=data_config['std'],
        num_workers=args.workers,
        distributed=args.distributed,
        collate_fn=collate_fn,
        pin_memory=args.pin_mem,
    )
    loader_train = create_loader(dataset_train,
                                 batch_size=args.batch_size,
                                 **train_loader_kwargs)
    # Training-data loader used below to re-estimate BatchNorm statistics.
    loader_bn = create_loader(
        dataset_train,
        batch_size=args.validation_batch_size_multiplier * args.batch_size,
        **train_loader_kwargs)

    eval_dir = os.path.join(args.data, 'val')
    if not os.path.isdir(eval_dir):
        eval_dir = os.path.join(args.data, 'validation')
        if not os.path.isdir(eval_dir):
            logging.error(
                'Validation folder does not exist at: {}'.format(eval_dir))
            raise SystemExit(1)
    dataset_eval = Dataset(eval_dir)
    loader_eval = create_loader(
        dataset_eval,
        input_size=data_config['input_size'],
        batch_size=args.validation_batch_size_multiplier * args.batch_size,
        is_training=False,
        use_prefetcher=args.prefetcher,
        interpolation=data_config['interpolation'],
        mean=data_config['mean'],
        std=data_config['std'],
        num_workers=args.workers,
        distributed=args.distributed,
        crop_pct=data_config['crop_pct'],
        pin_memory=args.pin_mem,
    )

    # ------------- loss_fn --------------
    if args.jsd:
        assert num_aug_splits > 1  # JSD only valid with aug splits set
        train_loss_fn = JsdCrossEntropy(num_splits=num_aug_splits,
                                        smoothing=args.smoothing).cuda()
        validate_loss_fn = nn.CrossEntropyLoss().cuda()
    elif args.smoothing:
        train_loss_fn = LabelSmoothingCrossEntropy(
            smoothing=args.smoothing).cuda()
        validate_loss_fn = nn.CrossEntropyLoss().cuda()
    else:
        train_loss_fn = nn.CrossEntropyLoss().cuda()
        validate_loss_fn = train_loss_fn
    if args.ieb:
        distill_loss_fn = SoftTargetCrossEntropy().cuda()
    else:
        distill_loss_fn = None

    # Every rank runs the profiling; only local rank 0 is verbose.
    model_profiling(model,
                    224,
                    224,
                    1,
                    3,
                    use_cuda=True,
                    verbose=args.local_rank == 0)

    if not args.test_mode:
        # start training
        for epoch in range(start_epoch, num_epochs):
            if args.distributed:
                loader_train.sampler.set_epoch(epoch)
            # train
            if args.gate_train:
                train_metrics = train_epoch_slim_gate(
                    epoch,
                    model,
                    loader_train,
                    optimizer,
                    train_loss_fn,
                    args,
                    lr_scheduler=lr_scheduler,
                    saver=saver,
                    output_dir=output_dir,
                    use_amp=use_amp,
                    model_ema=model_ema,
                    optimizer_step=args.optimizer_step)
            else:
                train_metrics = train_epoch_slim(
                    epoch,
                    model,
                    loader_train,
                    optimizer,
                    loss_fn=train_loss_fn,
                    distill_loss_fn=distill_loss_fn,
                    args=args,
                    lr_scheduler=lr_scheduler,
                    saver=saver,
                    output_dir=output_dir,
                    use_amp=use_amp,
                    model_ema=model_ema,
                    optimizer_step=args.optimizer_step,
                )
            if args.distributed and args.dist_bn in ('broadcast', 'reduce'):
                if args.local_rank == 0:
                    logging.info(
                        "Distributing BatchNorm running means and vars")
                distribute_bn(model, args.world_size, args.dist_bn == 'reduce')

            # eval: the 'uniform' mode is added on every 10th epoch except 0
            if args.gate_train:
                eval_sample_list = ['dynamic']
            elif epoch % 10 == 0 and epoch != 0:
                eval_sample_list = ['smallest', 'largest', 'uniform']
            else:
                eval_sample_list = ['smallest', 'largest']

            per_mode_metrics = [
                validate_slim(model,
                              loader_eval,
                              validate_loss_fn,
                              args,
                              model_mode=model_mode)
                for model_mode in eval_sample_list
            ]

            if model_ema is not None and not args.model_ema_force_cpu:
                # EMA results replace the raw model's results when available.
                per_mode_metrics = [
                    validate_slim(model_ema.ema,
                                  loader_eval,
                                  validate_loss_fn,
                                  args,
                                  model_mode=model_mode)
                    for model_mode in eval_sample_list
                ]

            # Scheduling, summary and checkpointing all use the metrics of
            # the first evaluated mode.
            eval_metrics = per_mode_metrics[0]

            if lr_scheduler is not None:
                # step LR for next epoch
                lr_scheduler.step(epoch + 1, eval_metrics[eval_metric])

            # save
            update_summary(epoch,
                           train_metrics,
                           eval_metrics,
                           os.path.join(output_dir, 'summary.csv'),
                           write_header=best_metric is None)

            if saver is not None:
                # save proper checkpoint with eval metric
                save_metric = eval_metrics[eval_metric]
                best_metric, best_epoch = saver.save_checkpoint(
                    model,
                    optimizer,
                    args,
                    epoch=epoch,
                    model_ema=model_ema,
                    metric=save_metric,
                    use_amp=use_amp)
        # end training
        if best_metric is not None:
            logging.info('*** Best metric: {0} (epoch {1})'.format(
                best_metric, best_epoch))

    # test
    eval_metrics = []
    for choice in range(args.num_choice):
        # reset bn if not smallest or largest
        if choice != 0 and choice != args.num_choice - 1:
            for layer in model.modules():
                if isinstance(layer, (nn.BatchNorm2d, nn.SyncBatchNorm)) or \
                        (has_apex and isinstance(layer, apex.parallel.SyncBatchNorm)):
                    layer.reset_running_stats()
            model.train()
            with torch.no_grad():
                # NOTE(review): when args.slim_train is false this loop only
                # iterates the loader and the reset statistics are never
                # re-estimated -- confirm that is intended.
                for batch_idx, (images, _target) in enumerate(loader_bn):
                    if args.slim_train:
                        if hasattr(model, 'module'):
                            model.module.set_mode('uniform', choice=choice)
                        else:
                            model.set_mode('uniform', choice=choice)
                        model(images)

                    # stop once batch index 1000 has been processed
                    if batch_idx % 1000 == 0 and batch_idx != 0:
                        print('Subnet {} : reset bn for {} steps'.format(
                            choice, batch_idx))
                        break
            if args.distributed and args.dist_bn in ('broadcast', 'reduce'):
                if args.local_rank == 0:
                    logging.info(
                        "Distributing BatchNorm running means and vars")
                distribute_bn(model, args.world_size, args.dist_bn == 'reduce')

        eval_metrics.append(
            validate_slim(model,
                          loader_eval,
                          validate_loss_fn,
                          args,
                          model_mode=choice))
    if args.local_rank == 0:
        print('Test results of the last epoch:\n', eval_metrics)
Пример #5
0
def main():
    parser = argparse.ArgumentParser()

    ## Required parameters
    parser.add_argument(
        "--data_dir",
        default=None,
        type=str,
        required=True,
        help=
        "The input data dir. Should contain the .tsv files (or other data files) for the task."
    )
    parser.add_argument("--model_type",
                        default=None,
                        type=str,
                        required=True,
                        help="Model type selected in the list: " +
                        ", ".join(MODEL_CLASSES.keys()))
    parser.add_argument(
        "--model_name_or_path",
        default=None,
        type=str,
        required=True,
        help="Path to pre-trained model or shortcut name selected in the list: "
        + ", ".join(ALL_MODELS))
    parser.add_argument(
        "--meta_path",
        default=None,
        type=str,
        required=False,
        help="Path to pre-trained model or shortcut name selected in the list: "
        + ", ".join(ALL_MODELS))
    parser.add_argument(
        "--output_dir",
        default=None,
        type=str,
        required=True,
        help=
        "The output directory where the model predictions and checkpoints will be written."
    )
    parser.add_argument(
        '--classifier',
        default='guoday',
        type=str,
        required=True,
        help='classifier type, guoday or MLP or GRU_MLP or ...')
    parser.add_argument('--optimizer',
                        default='RAdam',
                        type=str,
                        required=True,
                        help='optimizer we use, RAdam or ...')
    parser.add_argument("--do_label_smoothing",
                        default='yes',
                        type=str,
                        required=True,
                        help="Whether to do label smoothing. yes or no.")
    parser.add_argument('--draw_loss_steps',
                        default=1,
                        type=int,
                        required=True,
                        help='training steps to draw loss')
    parser.add_argument('--label_name',
                        default='label',
                        type=str,
                        required=True,
                        help='label name in original train set index')

    ## Other parameters
    parser.add_argument(
        "--config_name",
        default="",
        type=str,
        help="Pretrained config name or path if not the same as model_name")
    parser.add_argument(
        "--tokenizer_name",
        default="",
        type=str,
        help="Pretrained tokenizer name or path if not the same as model_name")
    parser.add_argument(
        "--cache_dir",
        default="",
        type=str,
        help=
        "Where do you want to store the pre-trained models downloaded from s3")
    parser.add_argument(
        "--max_seq_length",
        default=128,
        type=int,
        help=
        "The maximum total input sequence length after tokenization. Sequences longer "
        "than this will be truncated, sequences shorter will be padded.")
    parser.add_argument("--do_train",
                        default='yes',
                        type=str,
                        required=True,
                        help="Whether to run training. yes or no.")
    parser.add_argument("--do_test",
                        default='yes',
                        type=str,
                        required=True,
                        help="Whether to run training. yes or no.")
    parser.add_argument("--do_eval",
                        default='yes',
                        type=str,
                        required=True,
                        help="Whether to run eval on the dev set. yes or no.")
    parser.add_argument(
        "--evaluate_during_training",
        action='store_true',
        help="Rul evaluation during training at each logging step.")
    parser.add_argument(
        "--do_lower_case",
        action='store_true',
        help="Set this flag if you are using an uncased model.")

    parser.add_argument("--per_gpu_train_batch_size",
                        default=8,
                        type=int,
                        help="Batch size per GPU/CPU for training.")
    parser.add_argument("--per_gpu_eval_batch_size",
                        default=8,
                        type=int,
                        help="Batch size per GPU/CPU for evaluation.")
    parser.add_argument(
        '--gradient_accumulation_steps',
        type=int,
        default=1,
        help=
        "Number of updates steps to accumulate before performing a backward/update pass."
    )
    parser.add_argument("--learning_rate",
                        default=5e-5,
                        type=float,
                        help="The initial learning rate for Adam.")
    parser.add_argument("--weight_decay",
                        default=0.0,
                        type=float,
                        help="Weight deay if we apply some.")
    parser.add_argument("--adam_epsilon",
                        default=1e-8,
                        type=float,
                        help="Epsilon for Adam optimizer.")
    parser.add_argument("--max_grad_norm",
                        default=1.0,
                        type=float,
                        help="Max gradient norm.")
    parser.add_argument("--num_train_epochs",
                        default=3.0,
                        type=float,
                        help="Total number of training epochs to perform.")
    parser.add_argument(
        "--max_steps",
        default=-1,
        type=int,
        help=
        "If > 0: set total number of training steps to perform. Override num_train_epochs."
    )
    parser.add_argument("--eval_steps", default=200, type=int, help="")
    parser.add_argument("--lstm_hidden_size", default=300, type=int, help="")
    parser.add_argument("--lstm_layers", default=2, type=int, help="")
    parser.add_argument("--dropout", default=0.5, type=float, help="")

    parser.add_argument("--train_steps", default=-1, type=int, help="")
    parser.add_argument("--report_steps", default=-1, type=int, help="")
    parser.add_argument("--warmup_steps",
                        default=0,
                        type=int,
                        help="Linear warmup over warmup_steps.")
    parser.add_argument("--split_num", default=3, type=int, help="text split")
    parser.add_argument('--logging_steps',
                        type=int,
                        default=50,
                        help="Log every X updates steps.")
    parser.add_argument('--save_steps',
                        type=int,
                        default=50,
                        help="Save checkpoint every X updates steps.")
    parser.add_argument(
        "--eval_all_checkpoints",
        action='store_true',
        help=
        "Evaluate all checkpoints starting with the same prefix as model_name ending and ending with step number"
    )
    parser.add_argument("--no_cuda",
                        action='store_true',
                        help="Avoid using CUDA when available")
    parser.add_argument('--overwrite_output_dir',
                        action='store_true',
                        help="Overwrite the content of the output directory")
    parser.add_argument(
        '--overwrite_cache',
        action='store_true',
        help="Overwrite the cached training and evaluation sets")
    parser.add_argument('--seed',
                        type=int,
                        default=42,
                        help="random seed for initialization")

    parser.add_argument(
        '--fp16',
        action='store_true',
        help=
        "Whether to use 16-bit (mixed) precision (through NVIDIA apex) instead of 32-bit"
    )
    parser.add_argument(
        '--fp16_opt_level',
        type=str,
        default='O1',
        help=
        "For fp16: Apex AMP optimization level selected in ['O0', 'O1', 'O2', and 'O3']."
        "See details at https://nvidia.github.io/apex/amp.html")
    parser.add_argument("--local_rank",
                        type=int,
                        default=-1,
                        help="For distributed training: local_rank")
    parser.add_argument('--server_ip',
                        type=str,
                        default='',
                        help="For distant debugging.")
    parser.add_argument('--server_port',
                        type=str,
                        default='',
                        help="For distant debugging.")
    args = parser.parse_args()

    # Setup CUDA, GPU & distributed training
    if args.local_rank == -1 or args.no_cuda:
        device = torch.device("cuda" if torch.cuda.is_available()
                              and not args.no_cuda else "cpu")
        args.n_gpu = torch.cuda.device_count()
    else:  # Initializes the distributed backend which will take care of sychronizing nodes/GPUs
        torch.cuda.set_device(args.local_rank)
        device = torch.device("cuda", args.local_rank)
        torch.distributed.init_process_group(backend='nccl')
        args.n_gpu = 1
    args.device = device

    # Setup logging
    logging.basicConfig(
        format='%(asctime)s - %(levelname)s - %(name)s -   %(message)s',
        datefmt='%m/%d/%Y %H:%M:%S',
        level=logging.INFO if args.local_rank in [-1, 0] else logging.WARN)
    logger.warning(
        "Process rank: %s, device: %s, n_gpu: %s, distributed training: %s, 16-bits training: %s",
        args.local_rank, device, args.n_gpu, bool(args.local_rank != -1),
        args.fp16)

    # Set seed
    set_seed(args)

    try:
        os.makedirs(args.output_dir)
    except:
        pass

    tokenizer = BertTokenizer.from_pretrained(args.model_name_or_path,
                                              do_lower_case=args.do_lower_case)

    # tensorboard_log_dir = args.output_dir

    # loss_now = tf.placeholder(dtype=tf.float32, name='loss_now')
    # loss_mean = tf.placeholder(dtype=tf.float32, name='loss_mean')
    # loss_now_variable = loss_now
    # loss_mean_variable = loss_mean
    # train_loss = tf.summary.scalar('train_loss', loss_now_variable)
    # dev_loss_mean = tf.summary.scalar('dev_loss_mean', loss_mean_variable)
    # merged = tf.summary.merge([train_loss, dev_loss_mean])

    config = BertConfig.from_pretrained(args.model_name_or_path, num_labels=3)
    config.hidden_dropout_prob = args.dropout

    # Prepare model
    if args.do_train == 'yes':
        model = BertForSequenceClassification.from_pretrained(
            args.model_name_or_path, args, config=config)

        if args.fp16:
            model.half()
        model.to(device)
        if args.local_rank != -1:
            try:
                from apex.parallel import DistributedDataParallel as DDP
            except ImportError:
                raise ImportError(
                    "Please install apex from https://www.github.com/nvidia/apex to use distributed and fp16 training."
                )

            model = DDP(model)
        elif args.n_gpu > 1:
            model = torch.nn.DataParallel(model)

    args.train_batch_size = args.per_gpu_train_batch_size * max(1, args.n_gpu)
    args.eval_batch_size = args.per_gpu_eval_batch_size * max(1, args.n_gpu)

    if args.do_train == 'yes':
        print(
            '________________________now training______________________________'
        )
        # Prepare data loader

        train_examples = read_examples(os.path.join(args.data_dir,
                                                    'train.csv'),
                                       is_training=True,
                                       label_name=args.label_name)
        train_features = convert_examples_to_features(train_examples,
                                                      tokenizer,
                                                      args.max_seq_length,
                                                      args.split_num, True)
        # print('train_feature_size=', train_features.__sizeof__())
        all_input_ids = torch.tensor(select_field(train_features, 'input_ids'),
                                     dtype=torch.long)
        all_input_mask = torch.tensor(select_field(train_features,
                                                   'input_mask'),
                                      dtype=torch.long)
        all_segment_ids = torch.tensor(select_field(train_features,
                                                    'segment_ids'),
                                       dtype=torch.long)
        all_label = torch.tensor([f.label for f in train_features],
                                 dtype=torch.long)
        train_data = TensorDataset(all_input_ids, all_input_mask,
                                   all_segment_ids, all_label)
        # print('train_data=',train_data[0])
        if args.local_rank == -1:
            train_sampler = RandomSampler(train_data)
        else:
            train_sampler = DistributedSampler(train_data)
        train_dataloader = DataLoader(train_data,
                                      sampler=train_sampler,
                                      batch_size=args.train_batch_size //
                                      args.gradient_accumulation_steps)

        num_train_optimization_steps = args.train_steps

        # Prepare optimizer

        param_optimizer = list(model.named_parameters())

        # hack to remove pooler, which is not used
        # thus it produce None grad that break apex
        param_optimizer = [n for n in param_optimizer]

        no_decay = ['bias', 'LayerNorm.bias', 'LayerNorm.weight']
        optimizer_grouped_parameters = [{
            'params': [
                p for n, p in param_optimizer
                if not any(nd in n for nd in no_decay)
            ],
            'weight_decay':
            args.weight_decay
        }, {
            'params':
            [p for n, p in param_optimizer if any(nd in n for nd in no_decay)],
            'weight_decay':
            0.0
        }]

        if args.optimizer == 'RAdam':
            optimizer = RAdam(optimizer_grouped_parameters,
                              lr=args.learning_rate)
        else:
            optimizer = AdamW(optimizer_grouped_parameters,
                              lr=args.learning_rate,
                              eps=args.adam_epsilon)
        scheduler = WarmupLinearSchedule(optimizer,
                                         warmup_steps=args.warmup_steps,
                                         t_total=args.train_steps)

        global_step = 0

        logger.info("***** Running training *****")
        logger.info("  Num examples = %d", len(train_examples))
        logger.info("  Batch size = %d", args.train_batch_size)
        logger.info("  Num steps = %d", num_train_optimization_steps)

        best_acc = 0
        model.train()
        tr_loss = 0
        loss_batch = 0
        nb_tr_examples, nb_tr_steps = 0, 0
        bar = tqdm(range(num_train_optimization_steps),
                   total=num_train_optimization_steps)
        train_dataloader = cycle(train_dataloader)

        # with tf.Session() as sess:
        #     summary_writer = tf.summary.FileWriter(tensorboard_log_dir, sess.graph)
        #     sess.run(tf.global_variables_initializer())

        list_loss_mean = []
        bx = []
        eval_F1 = []
        ax = []

        for step in bar:
            batch = next(train_dataloader)
            batch = tuple(t.to(device) for t in batch)
            input_ids, input_mask, segment_ids, label_ids = batch
            loss = model(input_ids=input_ids,
                         token_type_ids=segment_ids,
                         attention_mask=input_mask,
                         labels=label_ids)

            if args.n_gpu > 1:
                loss = loss.mean()  # mean() to average on multi-gpu.
            if args.gradient_accumulation_steps > 1:
                loss = loss / args.gradient_accumulation_steps
            tr_loss += loss.item()
            loss_batch += loss.item()
            train_loss = round(
                tr_loss * args.gradient_accumulation_steps / (nb_tr_steps + 1),
                4)

            bar.set_description("loss {}".format(train_loss))
            nb_tr_examples += input_ids.size(0)
            nb_tr_steps += 1

            if args.fp16:
                # optimizer.backward(loss)
                loss.backward()
            else:
                loss.backward()

            # draw loss every n docs
            if (step + 1) % int(args.draw_loss_steps /
                                (args.train_batch_size /
                                 args.gradient_accumulation_steps)) == 0:
                list_loss_mean.append(round(loss_batch, 4))
                bx.append(step + 1)
                plt.plot(bx,
                         list_loss_mean,
                         label='loss_mean',
                         linewidth=1,
                         color='b',
                         marker='o',
                         markerfacecolor='green',
                         markersize=2)
                plt.savefig(args.output_dir + '/labeled.jpg')
                loss_batch = 0

            # paras update every batch data.
            if (nb_tr_steps + 1) % args.gradient_accumulation_steps == 0:
                if args.fp16:
                    # modify learning rate with special warm up BERT uses
                    # if args.fp16 is False, BertAdam is used that handles this automatically
                    lr_this_step = args.learning_rate * warmup_linear.get_lr(
                        global_step, args.warmup_proportion)
                    for param_group in optimizer.param_groups:
                        param_group['lr'] = lr_this_step
                scheduler.step()
                optimizer.step()
                optimizer.zero_grad()
                global_step += 1

            # report results every 200 real batch.
            if step % (args.eval_steps *
                       args.gradient_accumulation_steps) == 0 and step > 0:
                tr_loss = 0
                nb_tr_examples, nb_tr_steps = 0, 0
                logger.info("***** Report result *****")
                logger.info("  %s = %s", 'global_step', str(global_step))
                logger.info("  %s = %s", 'train loss', str(train_loss))

            # do evaluation totally 10 times during training stage.
            if args.do_eval == 'yes' and (step + 1) % int(
                    num_train_optimization_steps / 10) == 0 and step > 450:
                for file in ['dev.csv']:
                    inference_labels = []
                    gold_labels = []
                    inference_logits = []
                    eval_examples = read_examples(os.path.join(
                        args.data_dir, file),
                                                  is_training=True,
                                                  label_name=args.label_name)
                    eval_features = convert_examples_to_features(
                        eval_examples, tokenizer, args.max_seq_length,
                        args.split_num, False)
                    all_input_ids = torch.tensor(select_field(
                        eval_features, 'input_ids'),
                                                 dtype=torch.long)
                    all_input_mask = torch.tensor(select_field(
                        eval_features, 'input_mask'),
                                                  dtype=torch.long)
                    all_segment_ids = torch.tensor(select_field(
                        eval_features, 'segment_ids'),
                                                   dtype=torch.long)
                    all_label = torch.tensor([f.label for f in eval_features],
                                             dtype=torch.long)

                    eval_data = TensorDataset(all_input_ids, all_input_mask,
                                              all_segment_ids, all_label)

                    logger.info("***** Running evaluation *****")
                    logger.info("  Num examples = %d", len(eval_examples))
                    logger.info("  Batch size = %d", args.eval_batch_size)

                    # Run prediction for full data
                    eval_sampler = SequentialSampler(eval_data)
                    eval_dataloader = DataLoader(
                        eval_data,
                        sampler=eval_sampler,
                        batch_size=args.eval_batch_size)

                    model.eval()
                    eval_loss, eval_accuracy = 0, 0
                    nb_eval_steps, nb_eval_examples = 0, 0
                    for input_ids, input_mask, segment_ids, label_ids in eval_dataloader:
                        input_ids = input_ids.to(device)
                        input_mask = input_mask.to(device)
                        segment_ids = segment_ids.to(device)
                        label_ids = label_ids.to(device)

                        with torch.no_grad():
                            tmp_eval_loss = model(input_ids=input_ids,
                                                  token_type_ids=segment_ids,
                                                  attention_mask=input_mask,
                                                  labels=label_ids)
                            logits = model(input_ids=input_ids,
                                           token_type_ids=segment_ids,
                                           attention_mask=input_mask)

                        logits = logits.detach().cpu().numpy()
                        label_ids = label_ids.to('cpu').numpy()
                        inference_labels.append(np.argmax(logits, axis=1))
                        gold_labels.append(label_ids)
                        inference_logits.append(logits)
                        eval_loss += tmp_eval_loss.mean().item()
                        nb_eval_examples += input_ids.size(0)
                        nb_eval_steps += 1

                    gold_labels = np.concatenate(gold_labels, 0)
                    inference_labels = np.concatenate(inference_labels, 0)
                    inference_logits = np.concatenate(inference_logits, 0)
                    model.train()
                    ###############################################
                    num_gold_0 = np.sum(gold_labels == 0)
                    num_gold_1 = np.sum(gold_labels == 1)
                    num_gold_2 = np.sum(gold_labels == 2)

                    right_0 = 0
                    right_1 = 0
                    right_2 = 0
                    error_0 = 0
                    error_1 = 0
                    error_2 = 0

                    for gold_label, inference_label in zip(
                            gold_labels, inference_labels):
                        if gold_label == inference_label:
                            if gold_label == 0:
                                right_0 += 1
                            elif gold_label == 1:
                                right_1 += 1
                            else:
                                right_2 += 1
                        elif inference_label == 0:
                            error_0 += 1
                        elif inference_label == 1:
                            error_1 += 1
                        else:
                            error_2 += 1

                    recall_0 = right_0 / (num_gold_0 + 1e-5)
                    recall_1 = right_1 / (num_gold_1 + 1e-5)
                    recall_2 = right_2 / (num_gold_2 + 1e-5)
                    precision_0 = right_0 / (error_0 + right_0 + 1e-5)
                    precision_1 = right_1 / (error_1 + right_1 + 1e-5)
                    precision_2 = right_2 / (error_2 + right_2 + 1e-5)
                    f10 = 2 * precision_0 * recall_0 / (precision_0 +
                                                        recall_0 + 1e-5)
                    f11 = 2 * precision_1 * recall_1 / (precision_1 +
                                                        recall_1 + 1e-5)
                    f12 = 2 * precision_2 * recall_2 / (precision_2 +
                                                        recall_2 + 1e-5)

                    output_dev_result_file = os.path.join(
                        args.output_dir, "dev_results.txt")
                    with open(output_dev_result_file, 'a',
                              encoding='utf-8') as f:
                        f.write('precision:' + str(precision_0) + ' ' +
                                str(precision_1) + ' ' + str(precision_2) +
                                '\n')
                        f.write('recall:' + str(recall_0) + ' ' +
                                str(recall_1) + ' ' + str(recall_2) + '\n')
                        f.write('f1:' + str(f10) + ' ' + str(f11) + ' ' +
                                str(f12) + '\n' + '\n')

                    eval_loss = eval_loss / nb_eval_steps
                    eval_accuracy = accuracy(inference_logits, gold_labels)
                    # draw loss.
                    eval_F1.append(round(eval_accuracy, 4))
                    ax.append(step)
                    plt.plot(ax,
                             eval_F1,
                             label='eval_F1',
                             linewidth=1,
                             color='r',
                             marker='o',
                             markerfacecolor='blue',
                             markersize=2)
                    for a, b in zip(ax, eval_F1):
                        plt.text(a, b, b, ha='center', va='bottom', fontsize=8)
                    plt.savefig(args.output_dir + '/labeled.jpg')

                    result = {
                        'eval_loss': eval_loss,
                        'eval_F1': eval_accuracy,
                        'global_step': global_step,
                        'loss': train_loss
                    }

                    output_eval_file = os.path.join(args.output_dir,
                                                    "eval_results.txt")
                    with open(output_eval_file, "a") as writer:
                        for key in sorted(result.keys()):
                            logger.info("  %s = %s", key, str(result[key]))
                            writer.write("%s = %s\n" % (key, str(result[key])))
                        writer.write('*' * 80)
                        writer.write('\n')
                    if eval_accuracy > best_acc and 'dev' in file:
                        print("=" * 80)
                        print("more accurate model arises, now best F1 = ",
                              eval_accuracy)
                        print("Saving Model......")
                        best_acc = eval_accuracy
                        # Save a trained model, only save the model it-self
                        model_to_save = model.module if hasattr(
                            model, 'module') else model
                        output_model_file = os.path.join(
                            args.output_dir, "pytorch_model.bin")
                        torch.save(model_to_save.state_dict(),
                                   output_model_file)
                        print("=" * 80)
                    '''
                    if (step+1) / int(num_train_optimization_steps/10) > 9.5:
                        print("=" * 80)
                        print("End of training. Saving Model......")
                        # Save a trained model, only save the model it-self
                        model_to_save = model.module if hasattr(model, 'module') else model
                        output_model_file = os.path.join(args.output_dir, "pytorch_model_final_step.bin")
                        torch.save(model_to_save.state_dict(), output_model_file)
                        print("=" * 80)
                    '''

    if args.do_test == 'yes':
        start_time = time.time()
        print(
            '___________________now testing for best eval f1 model_________________________'
        )
        try:
            del model
        except:
            pass
        gc.collect()
        args.do_train = 'no'
        model = BertForSequenceClassification.from_pretrained(os.path.join(
            args.output_dir, "pytorch_model.bin"),
                                                              args,
                                                              config=config)
        model.half()
        for layer in model.modules():
            if isinstance(layer, torch.nn.modules.batchnorm._BatchNorm):
                layer.float()
        model.to(device)
        if args.local_rank != -1:
            try:
                from apex.parallel import DistributedDataParallel as DDP
            except ImportError:
                raise ImportError(
                    "Please install apex from "
                    "https://www.github.com/nvidia/apex to use distributed and fp16 training."
                )

            model = DDP(model)
        elif args.n_gpu > 1:
            model = torch.nn.DataParallel(model)

        for file, flag in [('test.csv', 'test')]:
            inference_labels = []
            gold_labels = []
            eval_examples = read_examples(os.path.join(args.data_dir, file),
                                          is_training=False,
                                          label_name=args.label_name)
            eval_features = convert_examples_to_features(
                eval_examples, tokenizer, args.max_seq_length, args.split_num,
                False)
            all_input_ids = torch.tensor(select_field(eval_features,
                                                      'input_ids'),
                                         dtype=torch.long)
            all_input_mask = torch.tensor(select_field(eval_features,
                                                       'input_mask'),
                                          dtype=torch.long)
            all_segment_ids = torch.tensor(select_field(
                eval_features, 'segment_ids'),
                                           dtype=torch.long)
            all_label = torch.tensor([f.label for f in eval_features],
                                     dtype=torch.long)

            eval_data = TensorDataset(all_input_ids, all_input_mask,
                                      all_segment_ids, all_label)
            # Run prediction for full data
            eval_sampler = SequentialSampler(eval_data)
            eval_dataloader = DataLoader(eval_data,
                                         sampler=eval_sampler,
                                         batch_size=args.eval_batch_size)

            model.eval()
            eval_loss, eval_accuracy = 0, 0
            nb_eval_steps, nb_eval_examples = 0, 0
            for input_ids, input_mask, segment_ids, label_ids in eval_dataloader:
                input_ids = input_ids.to(device)
                input_mask = input_mask.to(device)
                segment_ids = segment_ids.to(device)
                label_ids = label_ids.to(device)

                with torch.no_grad():
                    logits = model(
                        input_ids=input_ids,
                        token_type_ids=segment_ids,
                        attention_mask=input_mask).detach().cpu().numpy()
                    # print('test_logits=', logits)
                label_ids = label_ids.to('cpu').numpy()
                inference_labels.append(logits)
                gold_labels.append(label_ids)
            gold_labels = np.concatenate(gold_labels, 0)
            logits = np.concatenate(inference_labels, 0)
            if flag == 'dev':
                print(flag, accuracy(logits, gold_labels))
            elif flag == 'test':
                df = pd.read_csv(os.path.join(args.data_dir, file))
                df['label_0'] = logits[:, 0]
                df['label_1'] = logits[:, 1]
                df['label_2'] = logits[:, 2]
                df[['id', 'label_0', 'label_1',
                    'label_2']].to_csv(os.path.join(args.output_dir,
                                                    "sub.csv"),
                                       index=False)
                # df[['id', 'label_0', 'label_1']].to_csv(os.path.join(args.output_dir, "sub.csv"), index=False)
            else:
                raise ValueError('flag not in [dev, test]')
        print('inference time usd = {}s'.format(time.time() - start_time))
        '''