コード例 #1
0
def validate(val_loader, model, args):
    """Evaluate ``model`` on ``val_loader`` and report accuracy and throughput.

    Args:
        val_loader: Iterable yielding ``(inputs, labels)`` batches of tensors.
        model: Network to evaluate. It is switched to eval mode and left in
            that mode when the function returns.
        args: Not used by this function; kept so existing callers keep working.

    Returns:
        Tuple ``(top1_avg, top5_avg, throughput)``. The first two are the
        ``.avg`` of meters updated with each batch's top-1 / top-5 accuracy
        and batch size. ``throughput`` is the number of samples processed per
        second of wall-clock time, or ``0.0`` if no batch was processed.
    """
    top1 = AverageMeter()
    top5 = AverageMeter()

    # switch to evaluate mode
    model.eval()

    # Accumulate totals instead of relying on the size of the final batch:
    # the last batch is usually smaller, which would skew the throughput, and
    # with an empty loader there is no "last batch" at all.
    total_samples = 0
    total_seconds = 0.0

    with torch.no_grad():
        end = time.time()
        for inputs, labels in val_loader:
            inputs, labels = inputs.cuda(), labels.cuda()
            outputs = model(inputs)
            acc1, acc5 = accuracy(outputs, labels, topk=(1, 5))

            batch_size = inputs.size(0)
            top1.update(acc1.item(), batch_size)
            top5.update(acc5.item(), batch_size)

            now = time.time()
            total_seconds += now - end
            total_samples += batch_size
            end = now

    # Equivalent to the per-batch formula when every batch has the same size.
    throughput = total_samples / total_seconds if total_seconds > 0 else 0.0

    return top1.avg, top5.avg, throughput
コード例 #2
0
def validate(val_loader, model, args=None):
    """Evaluate ``model`` and compute its mean softmax cross-entropy loss.

    Args:
        val_loader: Iterable yielding ``(inputs, labels)`` batches of tensors.
            Labels are used as class indices into the model's output columns.
        model: Network to evaluate. It is switched to eval mode and left in
            that mode when the function returns.
        args: Not used by this function. Optional so that both
            ``validate(loader, model)`` and ``validate(loader, model, args)``
            call styles work.

    Returns:
        Tuple ``(top1_avg, top2_avg, throughput, loss)``:
        ``top1_avg`` / ``top2_avg`` are the ``.avg`` of meters updated with
        each batch's top-1 / top-2 accuracy and batch size; ``throughput`` is
        samples per second of wall-clock time (``0.0`` if nothing was
        processed); ``loss`` is the mean negative log-likelihood of the true
        class over all samples (``inf`` if nothing was processed). The loss is
        also logged as ``rb loss``.
    """
    top1 = AverageMeter()
    # NOTE: despite the name this meter tracks top-2 accuracy (topk=(1, 2)).
    top5 = AverageMeter()

    # switch to evaluate mode
    model.eval()

    nll_sum = 0.0
    total_samples = 0
    total_seconds = 0.0

    with torch.no_grad():
        end = time.time()
        for inputs, labels in val_loader:
            inputs, labels = inputs.cuda(), labels.cuda()
            logits = model(inputs)
            probs = F.softmax(logits, dim=1)
            acc1, acc5 = accuracy(probs, labels, topk=(1, 2))

            batch_size = inputs.size(0)
            top1.update(acc1.item(), batch_size)
            top5.update(acc5.item(), batch_size)

            now = time.time()
            total_seconds += now - end
            end = now

            # Summed cross-entropy straight from the logits: numerically
            # stable (no log(0) when a probability underflows), independent of
            # the number of classes, and a single device sync per batch
            # instead of one per sample/class comparison.
            nll_sum += F.cross_entropy(
                logits, labels.long(), reduction='sum').item()
            total_samples += labels.size(0)

    # No samples means there is no meaningful loss; report +inf so the result
    # can never be mistaken for an improvement.
    result = nll_sum / total_samples if total_samples else float('inf')
    logger.info(f'rb loss: {result}')

    throughput = total_samples / total_seconds if total_seconds > 0 else 0.0

    return top1.avg, top5.avg, throughput, result
コード例 #3
0
def train(train_loader, model, criterion, optimizer, scheduler, epoch, args):
    """Train ``model`` for one epoch with gradient accumulation.

    Args:
        train_loader: DataLoader wrapped in a ``DataPrefetcher``; its
            ``dataset`` length is used only for the progress display.
        model: Network to train; switched to train mode here.
        criterion: Callable ``criterion(outputs, labels)`` returning the loss.
        optimizer: Optimizer stepped every ``args.accumulation_steps``
            iterations.
        scheduler: LR scheduler, stepped once after the epoch.
        epoch: Epoch number, used in log messages only.
        args: Namespace providing ``per_node_batch_size``,
            ``accumulation_steps``, ``apex`` and ``print_interval``.

    Returns:
        Tuple ``(top1_avg, top5_avg, loss_avg)`` taken from the running
        meters. The recorded loss is the per-iteration loss *after* division
        by ``args.accumulation_steps``.
    """
    top1 = AverageMeter()
    top5 = AverageMeter()
    losses = AverageMeter()

    # switch to train mode
    model.train()

    iters = len(train_loader.dataset) // (args.per_node_batch_size * gpus_num)
    prefetcher = DataPrefetcher(train_loader)
    inputs, labels = prefetcher.next()
    iter_index = 1
    while inputs is not None:
        inputs, labels = inputs.cuda(), labels.cuda()

        outputs = model(inputs)
        # Scale so that the accumulated gradient matches one large batch.
        loss = criterion(outputs, labels) / args.accumulation_steps

        if args.apex:
            with amp.scale_loss(loss, optimizer) as scaled_loss:
                scaled_loss.backward()
        else:
            loss.backward()

        # NOTE(review): gradients left over when the epoch length is not a
        # multiple of accumulation_steps are carried into the next epoch.
        if iter_index % args.accumulation_steps == 0:
            optimizer.step()
            optimizer.zero_grad()

        # measure accuracy and record loss (read each scalar from the device
        # once and reuse it for both the meters and the log line)
        acc1, acc5 = accuracy(outputs, labels, topk=(1, 5))
        batch_size = inputs.size(0)
        top1_val, top5_val, loss_val = acc1.item(), acc5.item(), loss.item()
        top1.update(top1_val, batch_size)
        top5.update(top5_val, batch_size)
        losses.update(loss_val, batch_size)

        inputs, labels = prefetcher.next()

        if local_rank == 0 and iter_index % args.print_interval == 0:
            # Read the LR actually applied by the optimizer. Calling
            # scheduler.get_lr() outside of step() is deprecated and returns
            # incorrect values for chainable schedulers.
            current_lr = optimizer.param_groups[0]['lr']
            logger.info(
                f"train: epoch {epoch:0>3d}, iter [{iter_index:0>4d}, {iters:0>4d}], lr: {current_lr:.6f}, top1 acc: {top1_val:.2f}%, top5 acc: {top5_val:.2f}%, loss_total: {loss_val:.2f}"
            )

        iter_index += 1

    scheduler.step()

    return top1.avg, top5.avg, losses.avg
コード例 #4
0
def _warmup_cosine_factor(epoch, warm_up_epochs, total_epochs):
    """Return the LR multiplier: linear warm-up followed by cosine decay."""
    # Guard against warm_up_epochs == 0, which would otherwise divide by zero
    # at epoch 0.
    if warm_up_epochs > 0 and epoch <= warm_up_epochs:
        return epoch / warm_up_epochs
    # Guard against a zero-length decay span (total_epochs == warm_up_epochs).
    decay_span = max(total_epochs - warm_up_epochs, 1)
    return 0.5 * (
        math.cos((epoch - warm_up_epochs) / decay_span * math.pi) + 1)


def _freeze_parameters(model, logger):
    """Disable gradients for every parameter of ``model`` and log each one."""
    for name, param in model.named_parameters():
        param.requires_grad = False
        logger.info(f"{name},{param.requires_grad}")


def _load_backbone_weights(model, checkpoint_path,
                           skip_keys=('fc.weight', 'fc.bias')):
    """Load ``model_state_dict`` from a checkpoint into ``model``.

    A leading ``module.`` (added by ``nn.DataParallel``) is removed from each
    key when present, and the keys listed in ``skip_keys`` are dropped.
    """
    prefix = 'module.'
    checkpoint = torch.load(checkpoint_path, map_location=torch.device('cpu'))
    state_dict = OrderedDict()
    for key, value in checkpoint['model_state_dict'].items():
        # Only strip the prefix when it is really there; blindly cutting
        # seven characters would corrupt keys saved without DataParallel.
        name = key[len(prefix):] if key.startswith(prefix) else key
        if name not in skip_keys:
            state_dict[name] = value
    model.load_state_dict(state_dict)


def _train_one_epoch(train_loader, model, criterion, optimizer, epoch, args,
                     logger):
    """Run one training epoch and return the average loss of the epoch."""
    losses = AverageMeter()

    # switch to train mode
    # NOTE(review): this also puts the frozen backbones into train mode, so
    # any BatchNorm/Dropout layers inside them behave as in training.
    model.train()

    iters = len(train_loader.dataset) // args.batch_size
    prefetcher = DataPrefetcher(train_loader)
    inputs, labels = prefetcher.next()
    iter_index = 1

    while inputs is not None:
        inputs, labels = inputs.cuda(), labels.cuda()

        outputs = model(inputs)
        loss = criterion(outputs, labels)

        if args.apex:
            with amp.scale_loss(loss, optimizer) as scaled_loss:
                scaled_loss.backward()
        else:
            loss.backward()

        optimizer.step()
        optimizer.zero_grad()

        # measure accuracy and record loss (the second value is top-2)
        acc1, acc2 = accuracy(outputs, labels, topk=(1, 2))
        loss_val = loss.item()
        losses.update(loss_val, inputs.size(0))

        inputs, labels = prefetcher.next()

        if iter_index % args.print_interval == 0:
            # LR actually applied by the optimizer; scheduler.get_lr() is
            # deprecated outside of step().
            current_lr = optimizer.param_groups[0]['lr']
            logger.info(
                f"train: epoch {epoch:0>3d}, iter [{iter_index:0>4d}, {iters:0>4d}], lr: {current_lr:.6f}, top1 acc: {acc1.item():.2f}%, top5 acc: {acc2.item():.2f}%, loss_total: {loss_val:.2f}"
            )

        iter_index += 1

    return losses.avg


def _checkpoint_state(epoch, acc1, loss, model, optimizer, scheduler):
    """Build the dictionary that is written to a checkpoint file."""
    return {
        'epoch': epoch,
        'acc1': acc1,
        'loss': loss,
        'lr': optimizer.param_groups[0]['lr'],
        'model_state_dict': model.state_dict(),
        'optimizer_state_dict': optimizer.state_dict(),
        'scheduler_state_dict': scheduler.state_dict(),
    }


def main(logger, args):
    """Train the JSTNET ensemble on top of three frozen, pre-trained backbones.

    Builds the data loaders, creates the EfficientNet-B3 / RegNetY-12GF /
    RegNetY-32GF backbones, freezes them, wraps them in ``JSTNET``, loads the
    backbone weights from disk, optionally resumes from ``args.resume`` and
    then trains for epochs ``start_epoch .. args.epochs`` inclusive. After
    every epoch the model is validated; ``latest.pth`` is written whenever the
    validation loss improves, and a final named checkpoint is written after
    the last epoch.

    Args:
        logger: Logger used for all progress messages.
        args: Namespace providing ``seed``, ``batch_size``, ``num_workers``,
            ``lr``, ``momentum``, ``weight_decay``, ``warm_up_epochs``,
            ``epochs``, ``apex``, ``resume``, ``checkpoints`` and
            ``print_interval``.

    Raises:
        Exception: If no CUDA device is available.
    """
    if not torch.cuda.is_available():
        raise Exception("need gpu to train network!")

    deterministic = args.seed is not None
    if deterministic:
        random.seed(args.seed)
        torch.manual_seed(args.seed)
        torch.cuda.manual_seed_all(args.seed)
        cudnn.deterministic = True

    gpus = torch.cuda.device_count()
    logger.info(f'use {gpus} gpus')
    logger.info(f"args: {args}")

    # cuDNN autotuning may pick different kernels from run to run, which
    # defeats the determinism requested through args.seed.
    cudnn.benchmark = not deterministic
    cudnn.enabled = True
    start_time = time.time()

    # dataset and dataloader
    logger.info('start loading data')
    train_loader = DataLoader(Config.train_dataset,
                              batch_size=args.batch_size,
                              shuffle=True,
                              pin_memory=True,
                              num_workers=args.num_workers)
    val_loader = DataLoader(Config.val_dataset,
                            batch_size=args.batch_size,
                            shuffle=False,
                            pin_memory=True,
                            num_workers=args.num_workers)
    logger.info('finish loading data')

    # Backbones (a VoVNet99_se branch existed previously and is disabled).
    backbone_kwargs = {"pretrained": False, "num_classes": 4}
    model_b3 = models.__dict__['efficientnet_b3'](**backbone_kwargs)
    model_12GF = models.__dict__['RegNetY_12GF'](**backbone_kwargs)
    model_32GF = models.__dict__['RegNetY_32GF'](**backbone_kwargs)

    for backbone in (model_b3, model_12GF, model_32GF):
        _freeze_parameters(backbone, logger)

    # merge model
    logger.info(f"creating ensemble model")
    model = JSTNET(model_b3, model_12GF, model_32GF)
    model = model.cuda()
    model_b3 = model_b3.cuda()
    model_12GF = model_12GF.cuda()
    model_32GF = model_32GF.cuda()

    criterion = nn.CrossEntropyLoss().cuda()
    optimizer = torch.optim.SGD(model.parameters(),
                                args.lr,
                                momentum=args.momentum,
                                weight_decay=args.weight_decay)

    # warm_up_with_cosine_lr
    scheduler = torch.optim.lr_scheduler.LambdaLR(
        optimizer,
        lr_lambda=lambda epoch: _warmup_cosine_factor(
            epoch, args.warm_up_epochs, args.epochs))

    if args.apex:
        model, optimizer = amp.initialize(model, optimizer, opt_level='O1')

    model = nn.DataParallel(model)

    # load model
    my_path = '/home/jns2szh/code/pytorch-ImageNet-CIFAR-COCO-VOC-training-master/imagenet_experiments/'
    logger.info(f"start load model")
    _load_backbone_weights(
        model_b3,
        my_path + 'efficientnet_imagenet_DataParallel_train_example/checkpoints_b3/latest.pth')
    logger.info(f"load b3 model finished")

    _load_backbone_weights(
        model_12GF,
        my_path + 'regnet_imagenet_Dataparallel_train_example/regnet_12/latest.pth')
    logger.info(f"load 12GF model finished")

    _load_backbone_weights(
        model_32GF,
        my_path + 'regnet_imagenet_Dataparallel_train_example/checkpoints/latest.pth')
    logger.info(f"load 32GF model finished")

    # resume training
    start_epoch = 0
    if args.resume and os.path.exists(args.resume):
        logger.info(f"start resuming model from {args.resume}")
        checkpoint = torch.load(args.resume, map_location=torch.device('cpu'))
        # The checkpoint is written after its epoch (and scheduler step) has
        # finished, so training continues with the following epoch.
        start_epoch = checkpoint['epoch'] + 1
        model.load_state_dict(checkpoint['model_state_dict'])
        optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
        scheduler.load_state_dict(checkpoint['scheduler_state_dict'])
        # Older checkpoints stored the whole loss meter instead of a number.
        resumed_loss = checkpoint['loss']
        resumed_loss = getattr(resumed_loss, 'avg', resumed_loss)
        logger.info(
            f"finish resuming model from {args.resume}, epoch {checkpoint['epoch']}, "
            f"loss: {resumed_loss:3f}, lr: {checkpoint['lr']:.6f}, "
            f"top1_acc: {checkpoint['acc1']}%")

    os.makedirs(args.checkpoints, exist_ok=True)

    logger.info('start training')
    min_rb_loss = float('inf')

    for epoch in range(start_epoch, args.epochs + 1):
        train_loss = _train_one_epoch(train_loader, model, criterion,
                                      optimizer, epoch, args, logger)
        scheduler.step()

        acc1, acc5, throughput, rb_loss = validate(val_loader, model, args)
        logger.info(
            f"val: epoch {epoch:0>3d}, top1 acc: {acc1:.2f}%, top5 acc: {acc5:.2f}%, throughput: {throughput:.2f}sample/s"
        )

        improved = rb_loss < min_rb_loss
        is_last_epoch = epoch == args.epochs
        if improved or is_last_epoch:
            # Store the average loss as a plain number so the checkpoint can
            # be formatted and loaded without the meter class.
            state = _checkpoint_state(epoch, acc1, train_loss, model,
                                      optimizer, scheduler)

        if improved:
            min_rb_loss = rb_loss
            logger.info("save model")
            torch.save(state, os.path.join(args.checkpoints, 'latest.pth'))
        if is_last_epoch:
            torch.save(
                state,
                os.path.join(
                    args.checkpoints,
                    "{}-epoch{}-acc{}.pth".format('JSTNET', epoch, acc1)))

    training_time = (time.time() - start_time) / 3600
    logger.info(
        f"finish training, total training time: {training_time:.2f} hours")