Example #1
def get_network(self):
    args = self.args
    if args.network == 'alexnet':
        network = alexnet(128)
    elif args.network == 'alexnet_cifar':
        network = AlexNet_cifar(128)
    elif args.network == 'resnet18_cifar':
        network = ResNet18_cifar(128,
                                 dropout=args.dropout,
                                 non_linear_head=args.nonlinearhead,
                                 mlpbn=True)
    elif args.network == 'resnet50_cifar':
        network = ResNet50_cifar(128, dropout=args.dropout, mlpbn=True)
    elif args.network == 'wide_resnet28':
        network = WideResNetInstance(28, 2)
    elif args.network == 'resnet18':
        network = resnet18(non_linear_head=args.nonlinearhead, mlpbn=True)
    elif args.network == 'pre-resnet18':
        network = PreActResNet18(128)
    elif args.network == 'resnet50':
        network = resnet50(non_linear_head=args.nonlinearhead, mlpbn=True)
    elif args.network == 'pre-resnet50':
        network = PreActResNet50(128)
    elif args.network == 'shufflenet':
        network = shufflenet_v2_x1_0(num_classes=128,
                                     non_linear_head=args.nonlinearhead)
    else:
        raise NotImplementedError('network not supported: {}'.format(
            args.network))
    self.network = nn.DataParallel(network, device_ids=self.device_ids)
    self.network.to(self.device)
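
A quick aside on the dispatch pattern above: the if/elif chain is just a lookup from a name to a constructor plus its fixed keyword arguments, so it can also be written as a dictionary. The sketch below is self-contained and illustrative only; DummyNetA and DummyNetB are hypothetical stand-ins for the project's real network classes.

from functools import partial


class DummyNetA:  # hypothetical stand-in for something like ResNet18_cifar
    def __init__(self, dim, dropout=0.0):
        self.dim, self.dropout = dim, dropout


class DummyNetB:  # hypothetical stand-in for something like WideResNetInstance
    def __init__(self, depth, widen):
        self.depth, self.widen = depth, widen


# name -> constructor bound to its fixed arguments
NETWORKS = {
    'resnet18_cifar': partial(DummyNetA, 128, dropout=0.1),
    'wide_resnet28': partial(DummyNetB, 28, 2),
}


def build_network(name):
    try:
        return NETWORKS[name]()
    except KeyError:
        raise NotImplementedError('network not supported: {}'.format(name))


print(type(build_network('wide_resnet28')).__name__)  # -> DummyNetB
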
def main():

    global best_acc1
    best_acc1 = 0

    args = parse_option()

    if args.gpu is not None:
        print("Use GPU: {} for training".format(args.gpu))

    # set the data loader
    train_folder = os.path.join(args.data_folder, 'train')
    val_folder = os.path.join(args.data_folder, 'val')

    logger = getLogger(args.save_folder)
    if args.dataset.startswith('imagenet') or args.dataset.startswith(
            'places'):
        image_size = 224
        crop_padding = 32
        mean = [0.485, 0.456, 0.406]
        std = [0.229, 0.224, 0.225]
        normalize = transforms.Normalize(mean=mean, std=std)
        if args.aug == 'NULL':
            train_transform = transforms.Compose([
                transforms.RandomResizedCrop(image_size,
                                             scale=(args.crop, 1.)),
                transforms.RandomHorizontalFlip(),
                transforms.ToTensor(),
                normalize,
            ])
        elif args.aug == 'CJ':
            train_transform = transforms.Compose([
                transforms.RandomResizedCrop(image_size,
                                             scale=(args.crop, 1.)),
                transforms.RandomGrayscale(p=0.2),
                transforms.ColorJitter(0.4, 0.4, 0.4, 0.4),
                transforms.RandomHorizontalFlip(),
                transforms.ToTensor(),
                normalize,
            ])
        else:
            raise NotImplementedError('augmentation not supported: {}'.format(
                args.aug))

        val_transform = transforms.Compose([
            transforms.Resize(image_size + crop_padding),
            transforms.CenterCrop(image_size),
            transforms.ToTensor(),
            normalize,
        ])
        if args.dataset.startswith('imagenet'):
            train_dataset = datasets.ImageFolder(train_folder, train_transform)
            val_dataset = datasets.ImageFolder(
                val_folder,
                val_transform,
            )

        if args.dataset.startswith('places'):
            train_dataset = ImageList(
                '/data/trainvalsplit_places205/train_places205.csv',
                '/data/data/vision/torralba/deeplearning/images256',
                transform=train_transform,
                symbol_split=' ')
            val_dataset = ImageList(
                '/data/trainvalsplit_places205/val_places205.csv',
                '/data/data/vision/torralba/deeplearning/images256',
                transform=val_transform,
                symbol_split=' ')

        print(len(train_dataset))
        train_sampler = None

        train_loader = torch.utils.data.DataLoader(
            train_dataset,
            batch_size=args.batch_size,
            shuffle=(train_sampler is None),
            num_workers=args.n_workers,
            pin_memory=False,
            sampler=train_sampler)

        val_loader = torch.utils.data.DataLoader(val_dataset,
                                                 batch_size=args.batch_size,
                                                 shuffle=False,
                                                 num_workers=args.n_workers,
                                                 pin_memory=False)
    elif args.dataset.startswith('cifar'):
        train_loader, val_loader = cifar.get_linear_dataloader(args)
    elif args.dataset.startswith('svhn'):
        train_loader, val_loader = svhn.get_linear_dataloader(args)
    else:
        raise NotImplementedError('dataset not supported: {}'.format(
            args.dataset))

    # create model and optimizer
    if args.model == 'alexnet':
        if args.layer == 6:
            args.layer = 5
        model = AlexNet(128)
        model = nn.DataParallel(model)
        classifier = LinearClassifierAlexNet(args.layer, args.n_label, 'avg')
    elif args.model == 'alexnet_cifar':
        if args.layer == 6:
            args.layer = 5
        model = AlexNet_cifar(128)
        model = nn.DataParallel(model)
        classifier = LinearClassifierAlexNet(args.layer,
                                             args.n_label,
                                             'avg',
                                             cifar=True)
    elif args.model == 'resnet50':
        model = resnet50(non_linear_head=False)
        model = nn.DataParallel(model)
        classifier = LinearClassifierResNet(args.layer, args.n_label, 'avg', 1)
    elif args.model == 'resnet18':
        model = resnet18()
        model = nn.DataParallel(model)
        classifier = LinearClassifierResNet(args.layer,
                                            args.n_label,
                                            'avg',
                                            1,
                                            bottleneck=False)
    elif args.model == 'resnet18_cifar':
        model = resnet18_cifar()
        model = nn.DataParallel(model)
        classifier = LinearClassifierResNet(args.layer,
                                            args.n_label,
                                            'avg',
                                            1,
                                            bottleneck=False)
    elif args.model == 'resnet50_cifar':
        model = resnet50_cifar()
        model = nn.DataParallel(model)
        classifier = LinearClassifierResNet(args.layer, args.n_label, 'avg', 1)
    elif args.model == 'resnet50x2':
        model = InsResNet50(width=2)
        classifier = LinearClassifierResNet(args.layer, args.n_label, 'avg', 2)
    elif args.model == 'resnet50x4':
        model = InsResNet50(width=4)
        classifier = LinearClassifierResNet(args.layer, args.n_label, 'avg', 4)
    elif args.model == 'shufflenet':
        model = shufflenet_v2_x1_0(num_classes=128, non_linear_head=False)
        model = nn.DataParallel(model)
        classifier = LinearClassifierResNet(args.layer, args.n_label, 'avg',
                                            0.5)
    else:
        raise NotImplementedError('model not supported {}'.format(args.model))

    print('==> loading pre-trained model')
    ckpt = torch.load(args.model_path)
    if not args.moco:
        model.load_state_dict(ckpt['state_dict'])
    else:
        try:
            state_dict = ckpt['state_dict']
            for k in list(state_dict.keys()):
                # retain only encoder_q up to before the embedding layer
                if k.startswith('module.encoder_q'
                                ) and not k.startswith('module.encoder_q.fc'):
                    # remove prefix
                    state_dict['module.' +
                               k[len("module.encoder_q."):]] = state_dict[k]
                # delete renamed or unused k
                del state_dict[k]
            model.load_state_dict(state_dict)
        except Exception:
            # if the checkpoint layout does not match, skip loading the
            # pre-trained weights rather than crashing
            pass
    print("==> loaded checkpoint '{}' (epoch {})".format(
        args.model_path, ckpt['epoch']))
    print('==> done')

    model = model.cuda()
    classifier = classifier.cuda()

    criterion = torch.nn.CrossEntropyLoss().cuda(args.gpu)

    if not args.adam:
        optimizer = torch.optim.SGD(classifier.parameters(),
                                    lr=args.learning_rate,
                                    momentum=args.momentum,
                                    weight_decay=args.weight_decay)
    else:
        optimizer = torch.optim.Adam(classifier.parameters(),
                                     lr=args.learning_rate,
                                     betas=(args.beta1, args.beta2),
                                     weight_decay=args.weight_decay,
                                     eps=1e-8)

    model.eval()
    cudnn.benchmark = True

    # optionally resume from a checkpoint
    args.start_epoch = 1
    if args.resume:
        if os.path.isfile(args.resume):
            print("=> loading checkpoint '{}'".format(args.resume))
            checkpoint = torch.load(args.resume, map_location='cpu')
            # checkpoint = torch.load(args.resume)
            args.start_epoch = checkpoint['epoch'] + 1
            classifier.load_state_dict(checkpoint['classifier'])
            optimizer.load_state_dict(checkpoint['optimizer'])
            best_acc1 = checkpoint['best_acc1']
            print(best_acc1.item())
            best_acc1 = best_acc1.cuda()
            print("=> loaded checkpoint '{}' (epoch {})".format(
                args.resume, checkpoint['epoch']))
            if 'opt' in checkpoint.keys():
                # resume optimization hyper-parameters
                print('=> resume hyper parameters')
                if 'bn' in vars(checkpoint['opt']):
                    print('using bn: ', checkpoint['opt'].bn)
                if 'adam' in vars(checkpoint['opt']):
                    print('using adam: ', checkpoint['opt'].adam)
                #args.learning_rate = checkpoint['opt'].learning_rate
                # args.lr_decay_epochs = checkpoint['opt'].lr_decay_epochs
                args.lr_decay_rate = checkpoint['opt'].lr_decay_rate
                args.momentum = checkpoint['opt'].momentum
                args.weight_decay = checkpoint['opt'].weight_decay
                args.beta1 = checkpoint['opt'].beta1
                args.beta2 = checkpoint['opt'].beta2
            del checkpoint
            torch.cuda.empty_cache()
        else:
            print("=> no checkpoint found at '{}'".format(args.resume))

    # tensorboard
    tblogger = tb_logger.Logger(logdir=args.tb_folder, flush_secs=2)

    # routine
    best_acc = 0.0
    for epoch in range(args.start_epoch, args.epochs + 1):

        adjust_learning_rate(epoch, args, optimizer)
        print("==> training...")

        time1 = time.time()
        train_acc, train_acc5, train_loss = train(epoch, train_loader, model,
                                                  classifier, criterion,
                                                  optimizer, args)
        time2 = time.time()
        logging.info('train epoch {}, total time {:.2f}'.format(
            epoch, time2 - time1))

        logging.info(
            'Epoch: {}, lr:{} , train_loss: {:.4f}, train_acc: {:.4f}/{:.4f}'.
            format(epoch, optimizer.param_groups[0]['lr'], train_loss,
                   train_acc, train_acc5))

        tblogger.log_value('train_acc', train_acc, epoch)
        tblogger.log_value('train_acc5', train_acc5, epoch)
        tblogger.log_value('train_loss', train_loss, epoch)
        tblogger.log_value('learning_rate', optimizer.param_groups[0]['lr'],
                           epoch)

        test_acc, test_acc5, test_loss = validate(val_loader, model,
                                                  classifier, criterion, args)

        if test_acc >= best_acc:
            best_acc = test_acc

        logging.info(
            colorful(
                'Epoch: {}, val_loss: {:.4f}, val_acc: {:.4f}/{:.4f}, best_acc: {:.4f}'
                .format(epoch, test_loss, test_acc, test_acc5, best_acc)))
        tblogger.log_value('test_acc', test_acc, epoch)
        tblogger.log_value('test_acc5', test_acc5, epoch)
        tblogger.log_value('test_loss', test_loss, epoch)

        # save the best model
        if test_acc > best_acc1:
            best_acc1 = test_acc
            state = {
                'opt': args,
                'epoch': epoch,
                'classifier': classifier.state_dict(),
                'best_acc1': best_acc1,
                'optimizer': optimizer.state_dict(),
            }
            save_name = '{}_layer{}.pth'.format(args.model, args.layer)
            save_name = os.path.join(args.save_folder, save_name)
            print('saving best model!')
            torch.save(state, save_name)

        # save model
        if epoch % args.save_freq == 0:
            print('==> Saving...')
            state = {
                'opt': args,
                'epoch': epoch,
                'classifier': classifier.state_dict(),
                'best_acc1': test_acc,
                'optimizer': optimizer.state_dict(),
            }
            save_name = 'ckpt_epoch_{epoch}.pth'.format(epoch=epoch)
            save_name = os.path.join(args.save_folder, save_name)
            print('saving regular model!')
            torch.save(state, save_name)

        # tensorboard logger
        pass
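
The MoCo branch in main() above works by renaming checkpoint keys so that only the query encoder's backbone weights ('module.encoder_q.*', minus the projection head) are kept, re-prefixed with 'module.' to match the DataParallel-wrapped model. The following self-contained sketch runs the same renaming loop on a toy dict in place of a real checkpoint:

# toy stand-in for ckpt['state_dict'] from a MoCo checkpoint
state_dict = {
    'module.encoder_q.conv1.weight': 'w1',
    'module.encoder_q.layer1.0.bn1.bias': 'b1',
    'module.encoder_q.fc.0.weight': 'head',   # projection head: dropped
    'module.encoder_k.conv1.weight': 'w1_k',  # momentum encoder: dropped
}

for k in list(state_dict.keys()):
    # retain only encoder_q weights up to (not including) the fc head
    if k.startswith('module.encoder_q') and not k.startswith('module.encoder_q.fc'):
        state_dict['module.' + k[len('module.encoder_q.'):]] = state_dict[k]
    del state_dict[k]

print(sorted(state_dict))
# -> ['module.conv1.weight', 'module.layer1.0.bn1.bias']
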
Example #3
def main():
    parser = argparse.ArgumentParser(description='FireClassification')
    parser.add_argument('--model', default='ghostnet')
    parser.add_argument('--train_save_path',
                        default='/home/taekwang0094/WorkSpace/FireTraining2')
    parser.add_argument('--multi_gpus', default=True)
    parser.add_argument(
        '--root', default='/home/taekwang0094/WorkSpace/Summer_Conference')
    parser.add_argument('--channel_multiplier', default=[3.0, 1.0, 1.0])
    parser.add_argument('--batch_size', default=1)
    parser.add_argument('--epoch', default=100)

    args = parser.parse_args()

    checkpoint_path = '/home/taekwang0094/WorkSpace/FireTraining2/{0}_[{1},{2},{3}]_2/model_best.pt'.format(
        args.model, args.channel_multiplier[0], args.channel_multiplier[1],
        args.channel_multiplier[2])

    checkpoint = torch.load(checkpoint_path)
    # fall back to [1, 1, 1] when no channel multiplier is given
    if not args.channel_multiplier:
        channel_multiplier = [1, 1, 1]
    else:
        channel_multiplier = args.channel_multiplier
    if args.model == 'resnet18':
        # note: no resnet18 constructor is wired up in this snippet, so
        # 'model' is left undefined for this choice
        pass
    elif args.model == 'mobilenet_v2':
        model = mobilenet_v2(num_classes=2)
    elif args.model == 'shufflenet_v2_x1_0':
        model = shufflenet_v2_x1_0(num_classes=2)
    elif args.model == 'mobilenetv3_small':
        model = mobilenetv3_small(num_classes=2)
    elif args.model == 'ghostnet':
        model = ghostnet(num_classes=2)
    else:
        raise NotImplementedError('model not supported: {}'.format(args.model))

    model = torch.nn.DataParallel(model)
    model.cuda()
    optimizer = optim.SGD(model.parameters(), lr=0.001)
    model.load_state_dict(checkpoint['model_state_dict'])
    optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
    """
    data_transforms = {
        'train': transforms.Compose([
            # transforms.RandomResizedCrop(224),
            transforms.Resize((224, 224)),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
            transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
        ]),
        'val': transforms.Compose([
            transforms.Resize((224, 224)),
            # transforms.CenterCrop(224),
            transforms.ToTensor(),
            transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
        ]),
    }
    """

    data_transforms = {
        'train':
        transforms.Compose([
            # transforms.RandomResizedCrop(224),
            transforms.ToPILImage(),
            transforms.Resize((224, 224)),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
            transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
        ]),
        'val':
        transforms.Compose([
            transforms.ToPILImage(),
            transforms.Resize((224, 224)),
            # transforms.CenterCrop(224),
            transforms.ToTensor(),
            transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
        ]),
    }
    test_loader = torch.utils.data.DataLoader(
        dataset.FireDataset(args.root,
                            transforms=data_transforms['val'],
                            train='test',
                            channel_multiplier=channel_multiplier,
                            ch_preprocess=True),
        batch_size=int(args.batch_size),
        shuffle=True,
        num_workers=8,
    )

    val_total = 0
    val_correct = 0
    TP = 0
    TN = 0
    FP = 0
    FN = 0
    model.eval()
    total_inference_time = 0
    count = 0
    for batch_idx, (image, label) in enumerate(test_loader):
        count += 1
        image = image.to(device)
        label = label.to(device)
        # image, label = Variable(image).to(device), Variable(label).to(device)
        with torch.no_grad():
            time_1 = timer()
            output = model(image).to(device)
            time_2 = timer()
            inference_time = time_2 - time_1
            total_inference_time += inference_time
            _, predicted = torch.max(output.data, 1)
            #print("Predicted",batch_idx, predicted.item(),label.item() )
            val_label_eval = torch.argmax(label, 1)
            val_total += label.size(0)
            val_correct += (predicted == val_label_eval).sum().item()
            #print("pred : " ,predicted.item())
            #print("label : ", val_label_eval.item())

            # Precision, Recall, F1 Score
            if val_label_eval.item() == 1:
                if predicted.item() == 1:
                    TP += 1
                else:
                    FN += 1
            else:
                if predicted.item() == 1:
                    FP += 1
                else:
                    TN += 1

        print('{0}% done\r'.format(int(batch_idx / len(test_loader) * 100)),
              end="")
    accuracy = (TP + TN) / (TP + TN + FP + FN)
    recall = TP / (TP + FN)
    precision = TP / (TP + FP)
    f1_score = 2 * (precision * recall) / (precision + recall)
    print(
        "Model : ", args.model,
        "Channel Multiplier : [{0},{1},{2}]".format(channel_multiplier[0],
                                                    channel_multiplier[1],
                                                    channel_multiplier[2]))
    print("Average inference time : ", round(total_inference_time / count, 2),
          " Average FPS : ", round(count / total_inference_time, 2))
    print("Accuracy = ", round(accuracy, 4), "Precision = ",
          round(precision, 4), "recall = ", round(recall, 4), "f1 score = ",
          round(f1_score, 4))
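
One caveat about the metric block at the end: it divides by TP + FN and TP + FP directly, so a test split with no positive labels or no positive predictions raises ZeroDivisionError. A small guarded version of the same arithmetic, shown here with made-up counts, avoids that:

def classification_metrics(TP, TN, FP, FN):
    """Accuracy / precision / recall / F1 from raw counts, guarding empty denominators."""
    total = TP + TN + FP + FN
    accuracy = (TP + TN) / total if total else 0.0
    precision = TP / (TP + FP) if (TP + FP) else 0.0
    recall = TP / (TP + FN) if (TP + FN) else 0.0
    f1 = 2 * precision * recall / (precision + recall) if (precision + recall) else 0.0
    return accuracy, precision, recall, f1


print(classification_metrics(TP=42, TN=50, FP=3, FN=5))  # example counts only
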
Example #4
File: main_mgd.py  Project: unsky/mgd
def main_worker(gpu, ngpus_per_node, args):
    global best_acc1, best_acc5
    args.gpu = gpu

    if args.gpu is not None:
        print("Use GPU: {} for training".format(args.gpu))

    if args.distributed:
        if args.dist_url == "env://" and args.rank == -1:
            args.rank = int(os.environ["RANK"])
        if args.multiprocessing_distributed:
            # For multiprocessing distributed training, rank needs to be the
            # global rank among all the processes
            args.rank = args.rank * ngpus_per_node + gpu
        dist.init_process_group(backend=args.dist_backend,
                                init_method=args.dist_url,
                                world_size=args.world_size,
                                rank=args.rank)

    # create model
    if args.arch == 'mobilenet_v1':
        from models import resnet, mobilenetv1
        t_net = resnet.resnet50(pretrained=True)
        s_net = mobilenetv1.mobilenet_v1()
        ignore_inds = []
    elif args.arch == 'mobilenet_v2':
        from models import resnet, mobilenetv2
        t_net = resnet.resnet50(pretrained=True)
        s_net = mobilenetv2.mobilenet_v2(pretrained=bool(args.use_pretrained))
        ignore_inds = [0]
    elif args.arch == 'resnet50':
        from models import resnet
        t_net = resnet.resnet152(pretrained=True)
        s_net = resnet.resnet50(pretrained=bool(args.use_pretrained))
        ignore_inds = []
    elif args.arch == 'shufflenet_v2':
        from models import resnet, shufflenetv2
        t_net = resnet.resnet50(pretrained=True)
        s_net = shufflenetv2.shufflenet_v2_x1_0(
            pretrained=bool(args.use_pretrained))
        ignore_inds = [0]
    else:
        raise ValueError('unknown arch: {}'.format(args.arch))

    if args.distiller == 'mgd':
        # normal and special reducers
        norm_reducers = ['amp', 'rd', 'sp']
        spec_reducers = ['sm']
        assert args.mgd_reducer in norm_reducers + spec_reducers

        # create distiller
        distiller = mgd.builder.MGDistiller if args.mgd_reducer in norm_reducers \
               else mgd.builder.SMDistiller

        d_net = distiller(t_net,
                          s_net,
                          ignore_inds=ignore_inds,
                          reducer=args.mgd_reducer,
                          sync_bn=args.sync_bn,
                          with_kd=args.mgd_with_kd,
                          preReLU=True,
                          distributed=args.distributed)
    else:
        raise NotImplementedError

    # model size
    print('the number of teacher model parameters: {}'.format(
        sum([p.data.nelement() for p in t_net.parameters()])))
    print('the number of student model parameters: {}'.format(
        sum([p.data.nelement() for p in s_net.parameters()])))
    print('the total number of model parameters: {}'.format(
        sum([p.data.nelement() for p in d_net.parameters()])))

    # dp convert
    if not torch.cuda.is_available():
        print('using CPU, this will be slow')
    elif args.distributed:
        # For multiprocessing distributed, DistributedDataParallel constructor
        # should always set the single device scope, otherwise,
        # DistributedDataParallel will use all available devices.
        if args.gpu is not None:
            torch.cuda.set_device(args.gpu)
            t_net.cuda(args.gpu)
            s_net.cuda(args.gpu)
            d_net.cuda(args.gpu)
            # When using a single GPU per process and per
            # DistributedDataParallel, we need to divide the batch size
            # ourselves based on the total number of GPUs we have
            args.batch_size = int(args.batch_size / ngpus_per_node)
            args.workers = int(
                (args.workers + ngpus_per_node - 1) / ngpus_per_node)
            if args.sync_bn:
                s_net = torch.nn.SyncBatchNorm.convert_sync_batchnorm(s_net)
            t_net = torch.nn.parallel.DistributedDataParallel(
                t_net, find_unused_parameters=True, device_ids=[args.gpu])
            s_net = torch.nn.parallel.DistributedDataParallel(
                s_net, find_unused_parameters=True, device_ids=[args.gpu])
            d_net = torch.nn.parallel.DistributedDataParallel(
                d_net, find_unused_parameters=True, device_ids=[args.gpu])
        else:
            t_net.cuda()
            s_net.cuda()
            d_net.cuda()
            # DistributedDataParallel will divide and allocate batch_size to all
            # available GPUs if device_ids are not set
            if args.sync_bn:
                s_net = torch.nn.SyncBatchNorm.convert_sync_batchnorm(s_net)
            t_net = torch.nn.parallel.DistributedDataParallel(
                t_net, find_unused_parameters=True)
            s_net = torch.nn.parallel.DistributedDataParallel(
                s_net, find_unused_parameters=True)
            d_net = torch.nn.parallel.DistributedDataParallel(
                d_net, find_unused_parameters=True)
    elif args.gpu is not None:
        torch.cuda.set_device(args.gpu)
        t_net = t_net.cuda(args.gpu)
        s_net = s_net.cuda(args.gpu)
        d_net = d_net.cuda(args.gpu)
    else:
        # DataParallel will divide and allocate batch_size to all available GPUs
        t_net = torch.nn.DataParallel(t_net).cuda()
        s_net = torch.nn.DataParallel(s_net).cuda()
        d_net = torch.nn.DataParallel(d_net).cuda()

    # define loss function (criterion), optimizer and lr_scheduler
    criterion = nn.CrossEntropyLoss().cuda(args.gpu)

    model_params = list(s_net.parameters()) + list(
        d_net.module.BNs.parameters())
    optimizer = torch.optim.SGD(model_params,
                                args.lr,
                                momentum=args.momentum,
                                weight_decay=args.weight_decay,
                                nesterov=True)
    # warmup setting
    if args.warmup:
        args.epochs += args.warmup_epochs
        args.lr_drop_epochs = list(
            np.array(args.lr_drop_epochs) + args.warmup_epochs)
    lr_scheduler = build_lr_scheduler(optimizer, args)

    # optionally resume from a checkpoint
    load_checkpoint(t_net, args.teacher_resume, args)
    load_checkpoint(s_net, args.student_resume, args)

    cudnn.benchmark = True

    # Data loading code
    traindir = os.path.join(args.data, 'train')
    validdir = os.path.join(args.data, 'valid')
    normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                     std=[0.229, 0.224, 0.225])

    train_dataset = datasets.ImageFolder(
        traindir,
        transforms.Compose([
            transforms.RandomResizedCrop(224),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
            normalize,
        ]))

    if args.distributed:
        train_sampler = torch.utils.data.distributed.DistributedSampler(
            train_dataset)
    else:
        train_sampler = None

    train_loader = torch.utils.data.DataLoader(train_dataset,
                                               batch_size=args.batch_size,
                                               shuffle=(train_sampler is None),
                                               num_workers=args.workers,
                                               pin_memory=True,
                                               sampler=train_sampler)

    valid_loader = torch.utils.data.DataLoader(datasets.ImageFolder(
        validdir,
        transforms.Compose([
            transforms.Resize(256),
            transforms.CenterCrop(224),
            transforms.ToTensor(),
            normalize,
        ])),
                                               batch_size=args.batch_size,
                                               shuffle=False,
                                               num_workers=args.workers,
                                               pin_memory=True)

    if args.distributed:
        extra_sampler = mgd.sampler.ExtraDistributedSampler(train_dataset)
    else:
        extra_sampler = None

    extra_loader = torch.utils.data.DataLoader(datasets.ImageFolder(
        traindir,
        transforms.Compose([
            transforms.Resize(256),
            transforms.CenterCrop(224),
            transforms.ToTensor(),
            normalize,
        ])),
                                               batch_size=args.batch_size,
                                               shuffle=(extra_sampler is None),
                                               num_workers=args.workers,
                                               pin_memory=True,
                                               sampler=extra_sampler)

    print('=> evaluate teacher model')
    validate(valid_loader, t_net, criterion, args)
    print('=> evaluate student model')
    validate(valid_loader, s_net, criterion, args)
    if args.evaluate:
        return

    if args.distiller == 'mgd':
        mgd_update(extra_loader, d_net, args)

    for epoch in range(args.start_epoch, args.epochs):
        if args.distributed:
            train_sampler.set_epoch(epoch)

        # train for one epoch
        train(train_loader, d_net, criterion, optimizer, lr_scheduler, epoch,
              args)

        # evaluate on validation set
        acc1, acc5 = validate(valid_loader, s_net, criterion, args)

        # update flow matrix for the next round
        if args.distiller == 'mgd' and (epoch + 1) % args.mgd_update_freq == 0:
            mgd_update(extra_loader, d_net, args)

        # remember best acc@1 and save checkpoint
        is_best = acc1 > best_acc1
        best_acc1 = max(acc1, best_acc1)
        best_acc5 = acc5 if is_best else best_acc5

        print(' * - Best - Err@1 {acc1:.3f} Err@5 {acc5:.3f}'.format(
            acc1=(100 - best_acc1), acc5=(100 - best_acc5)))

        if not args.multiprocessing_distributed or (
                args.multiprocessing_distributed
                and args.rank % ngpus_per_node == 0):
            filename = '{}.pth'.format(args.arch)
            save_checkpoint(
                args, {
                    'epoch': epoch + 1,
                    'arch': args.arch,
                    'state_dict': s_net.state_dict(),
                    'best_acc1': best_acc1,
                    'best_acc5': acc5,
                    'optimizer': optimizer.state_dict(),
                }, is_best, filename)
        lr_scheduler.step()
        gc.collect()
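
build_lr_scheduler is project code and is not shown here. For readers who want the warmup-then-step-decay behaviour that the warmup block above implies, a minimal sketch with torch's LambdaLR is given below; the warmup length, drop epochs and drop rate are illustrative assumptions, not values taken from main_mgd.py.

import torch

model = torch.nn.Linear(8, 2)  # toy model, only needed to own some parameters
optimizer = torch.optim.SGD(model.parameters(), lr=0.1, momentum=0.9, nesterov=True)

warmup_epochs = 5              # assumed
lr_drop_epochs = [30, 60, 90]  # assumed
lr_drop_rate = 0.1             # assumed


def lr_lambda(epoch):
    if epoch < warmup_epochs:
        return (epoch + 1) / warmup_epochs  # linear warmup to the base lr
    return lr_drop_rate ** sum(epoch >= e for e in lr_drop_epochs)


lr_scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda)

for epoch in range(10):
    # train(...) would run here
    lr_scheduler.step()
    print(epoch, optimizer.param_groups[0]['lr'])
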
Example #5
def main():
    parser = argparse.ArgumentParser(description='FireClassification')
    parser.add_argument('--model', default='mobilenet_v2')  # resnet18, mobilenet_v2, sh
    parser.add_argument('--train_save_path',
                        default='/home/taekwang0094/WorkSpace/FireTraining')
    parser.add_argument('--multi_gpus', default=True)
    parser.add_argument(
        '--root', default='/home/taekwang0094/WorkSpace/Summer_Conference')
    # TODO: take channel_multiplier as a list via a -l style argument
    parser.add_argument('--channel_multiplier', default=[1, 1, 1])
    parser.add_argument('--batch_size', default=256)
    parser.add_argument('--epoch', default=100)

    args = parser.parse_args()
    print(args.model)

    channel_multiplier = args.channel_multiplier

    data_transforms = {
        'train': transforms.Compose([
            #transforms.RandomResizedCrop(224),
            transforms.ToPILImage(),
            transforms.Resize((224,224)),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
            transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
        ]),
        'val': transforms.Compose([
            transforms.ToPILImage(),
            transforms.Resize((224,224)),
            #transforms.CenterCrop(224),
            transforms.ToTensor(),
            transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
        ]),
    }
    train_loader = torch.utils.data.DataLoader(
        dataset.FireDataset(args.root,
                            transforms=data_transforms['train'],
                            channel_multiplier=channel_multiplier),
        batch_size=int(args.batch_size),
        shuffle=True,
        num_workers=8,
    )
    val_loader = torch.utils.data.DataLoader(
        dataset.FireDataset(args.root,
                            train='val',
                            transforms=data_transforms['val'],
                            channel_multiplier=channel_multiplier),
        batch_size=128,
        shuffle=False,
        num_workers=8,
    )
    if args.model == 'resnet18':
        model = resnet18()
        num_ftrs = model.fc.in_features
        # model.fc = nn.Linear(num_ftrs, 1)
        model.fc = nn.Sequential(
            nn.Linear(num_ftrs, 2),
            nn.Sigmoid()
        )
    elif args.model == 'mobilenet_v2':
        model = mobilenet_v2(num_classes=2)
    elif args.model == 'shufflenet_v2_x1_0':
        model = shufflenet_v2_x1_0(num_classes=2)
    elif args.model == 'mobilenetv3_small':
        model = mobilenetv3_small(num_classes=2)
    elif args.model == 'ghostnet':
        model = ghostnet(num_classes=2)
    elif args.model == 'mnasnet1_0':
        model = mnasnet1_0(num_classes=2)
    else:
        raise NotImplementedError('model not supported: {}'.format(args.model))

    print("Training Start, model: ", args.model)

    save_dir = args.model + "_[{0},{1},{2}]".format(channel_multiplier[0],channel_multiplier[1],channel_multiplier[2])
    #save_dir = os.path.join(args.model,'_[{0},{1},{2}]'.format(channel_multiplier[0],channel_multiplier[1],channel_multiplier[2]))
    save_dir_path = os.path.join(args.train_save_path,save_dir)
    if not os.path.exists(save_dir_path):
        os.makedirs(save_dir_path)


    if args.multi_gpus:
        model = torch.nn.DataParallel(model)
        model.cuda()
    else:
        model = model.to(device)

    if args.model == 'mnasnet1_0':
        optimizer = optim.SGD(model.parameters(), lr=0.003)
    else:
        optimizer = optim.SGD(model.parameters(), lr=0.001)
    # define the scheduler for every model so that scheduler.step() in the
    # training loop is always valid
    scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=10, gamma=0.5)
    criterion = torch.nn.BCELoss().to(device)

    best_val_loss = sys.maxsize

    file_txt = open(os.path.join(save_dir_path,'training_log.txt'),'w')


    early_stopping = EarlyStopping(patience=5, verbose=True)
    # track the training loss while the model trains
    train_losses = []
    # track the validation loss while the model trains
    valid_losses = []
    # track the average training loss per epoch
    avg_train_losses = []
    # track the average validation loss per epoch
    avg_valid_losses = []

    time_1 = 0
    time_2 = 0

    for epochs in range(args.epoch):
        model.train()

        train_correct = 0
        train_total = 0

        val_correct = 0
        val_total = 0
        time_1 = timer()
        for batch_idx, (image, label) in enumerate(train_loader):

            image = image.to(device)
            label = label.to(device)
            # image, label = Variable(image).to(device), Variable(label).to(device)
            output = model(image).to(device)
            # print(output)

            cost = criterion(output, label).to(device)
            optimizer.zero_grad()
            cost.backward()
            optimizer.step()
            # note: the scheduler is stepped every batch here, so with
            # StepLR(step_size=10, gamma=0.5) the LR halves every 10 batches
            scheduler.step()

            train_losses.append(cost.item())

            # print(output)
            _, predicted = torch.max(output.data, 1)
            # print("ASD", predicted)
            # print(torch.argmax(label, 1))
            train_label_eval = torch.argmax(label, 1)
            train_total += label.size(0)
            train_correct += (predicted == train_label_eval).sum().item()
            print('{0}% done\r'.format(int(batch_idx / len(train_loader) * 100)),
                  end="")

        torch.save({
            'epoch': epochs,
            'model_state_dict': model.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
            'loss': cost,
        }, os.path.join(save_dir_path,'model_epoch{0}_{1}.pt'.format(epochs,round(cost.item(),4))))

        for count, (image, label) in enumerate(val_loader):
            image = image.to(device)
            label = label.to(device)

            model.eval()
            with torch.no_grad():
                output = model(image).to(device)
                val_loss = criterion(output, label)
                valid_losses.append(val_loss.item())
                _, predicted = torch.max(output.data, 1)
                val_label_eval = torch.argmax(label, 1)
                val_total += label.size(0)
                val_correct += (predicted == val_label_eval).sum().item()
            print('{0}% done\r'.format(int(count / len(val_loader) * 100)),
                  end="")

        # note: this compares only the last validation batch's loss, not the epoch average
        if val_loss.item() < best_val_loss:
            best_val_loss = val_loss.item()
            print("best model renew.")
            torch.save({
                'epoch': epochs,
                'model_state_dict': model.state_dict(),
                'optimizer_state_dict': optimizer.state_dict(),
                'loss': cost,
            }, os.path.join(save_dir_path, 'model_best.pt'))
        train_accuracy = round(float(100 * train_correct / train_total), 4)
        val_accuracy = round(float(100 * val_correct / val_total), 4)

        # average loss per epoch (note: train_losses / valid_losses are never
        # cleared, so these averages run over every epoch seen so far)
        train_loss = np.average(train_losses)
        valid_loss = np.average(valid_losses)
        avg_train_losses.append(train_loss)
        avg_valid_losses.append(valid_loss)

        print("Epoch : ", epochs, " Training Loss : ", round(cost.item(), 4),
              " Training Accuracy : ", train_accuracy, " Validation Loss : ",
              round(val_loss.item(), 4), " Validation Accuracy : ", val_accuracy)
        word = ("Epoch : {0} Training Loss : {1} Training Accuracy {2} "
                "Validation Loss : {3} Validation Accuracy {4} \n").format(
                    epochs, round(cost.item(), 4), train_accuracy,
                    round(val_loss.item(), 4), val_accuracy)
        file_txt.write(word)
        time_2 = timer()
        print("Train time / epoch : ", time_2 - time_1)

        early_stopping(valid_loss, model)
        if early_stopping.early_stop:
            print("Early stopping")
            break

    file_txt.close()
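
EarlyStopping in this example comes from the project and is not defined here. A minimal sketch of the interface the loop relies on, namely early_stopping(valid_loss, model) plus an early_stop flag, might look like the class below; the checkpoint path and the decision to save the best weights are assumptions, not necessarily the project's exact behaviour.

import torch


class EarlyStopping:
    """Stop training when the validation loss has not improved for `patience` epochs."""

    def __init__(self, patience=5, verbose=False, path='checkpoint.pt'):
        self.patience = patience
        self.verbose = verbose
        self.path = path
        self.best_loss = None
        self.counter = 0
        self.early_stop = False

    def __call__(self, val_loss, model):
        if self.best_loss is None or val_loss < self.best_loss:
            self.best_loss = val_loss
            self.counter = 0
            torch.save(model.state_dict(), self.path)  # keep the best weights so far
            if self.verbose:
                print('validation loss improved to {:.4f}'.format(val_loss))
        else:
            self.counter += 1
            if self.counter >= self.patience:
                self.early_stop = True
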
Example #6
def main(args):
    BATCH_SIZE = args.batch_size
    LR = args.learning_rate
    EPOCH = args.epoch

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    use_gpu = torch.cuda.is_available()

    # note: this is a set literal (not a dict) and is never used below;
    # the 'transform' defined next is what both datasets actually use
    data_transforms = {
        transforms.Compose([
            transforms.Resize(320),
            transforms.CenterCrop(299),
            transforms.ToTensor(),
            transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
        ])
    }
    transform = transforms.Compose([
        transforms.Resize(size=(227, 227)),
        transforms.RandomRotation(20),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),  # convert the image to a Tensor, scaled to [0, 1]
        transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
    ])
    train_dataset = torchvision.datasets.ImageFolder(root=args.train_images,
                                                     transform=transform)
    train_loader = torch.utils.data.DataLoader(train_dataset,
                                               batch_size=BATCH_SIZE,
                                               shuffle=True)

    # load the validation data from the folder
    validation_dataset = torchvision.datasets.ImageFolder(
        root=args.test_images, transform=transform)
    print(validation_dataset.class_to_idx)

    test_loader = torch.utils.data.DataLoader(validation_dataset,
                                              batch_size=BATCH_SIZE,
                                              shuffle=True)

    if args.model_name == "densenet":
        Net = densenet.DenseNet().to(device)
    if args.model_name == "alexnet":
        Net = alexnet.AlexNet().to(device)
    if args.model_name == "googlenet":
        Net = googlenet.GoogLeNet().to(device)
    if args.model_name == "mobilenet":
        Net = mobilenet.MobileNetV2().to(device)
    if args.model_name == "mnasnet":
        Net = mnasnet.mnasnet1_0().to(device)
    if args.model_name == "squeezenet":
        Net = squeezenet.SqueezeNet().to(device)
    if args.model_name == "resnet":
        Net = resnet.resnet50().to(device)
    if args.model_name == "vgg":
        Net = vgg.vgg19().to(device)
    if args.model_name == "shufflenetv2":
        Net = shufflenetv2.shufflenet_v2_x1_0().to(device)

    criterion = nn.CrossEntropyLoss()
    opti = torch.optim.Adam(Net.parameters(), lr=LR)

    # note: this guard sits inside main(args), so the block below only runs
    # when the script is executed directly
    if __name__ == '__main__':
        Accuracy_list = []
        Loss_list = []

        for epoch in range(EPOCH):
            sum_loss = 0.0
            correct1 = 0

            total1 = 0
            for i, (images, labels) in enumerate(train_loader):
                num_images = images.size(0)

                images = Variable(images.to(device))
                labels = Variable(labels.to(device))

                if args.model_name == 'googlenet':
                    out = Net(images)
                    out = out[0]
                else:
                    out = Net(images)
                _, predicted = torch.max(out.data, 1)

                total1 += labels.size(0)

                correct1 += (predicted == labels).sum().item()

                loss = criterion(out, labels)
                print(loss)
                opti.zero_grad()
                loss.backward()
                opti.step()

                # print the running average loss every 10 batches
                sum_loss += loss.item()
                if i % 10 == 9:
                    print('train loss [%d, %d] loss: %.03f' %
                          (epoch + 1, i + 1, sum_loss / 10))
                    print("train acc %.03f" % (100.0 * correct1 / total1))
                    sum_loss = 0.0
            Accuracy_list.append(100.0 * correct1 / total1)
            print('accuracy={}'.format(100.0 * correct1 / total1))
            Loss_list.append(loss.item())

        x1 = range(0, EPOCH)
        x2 = range(0, EPOCH)
        y1 = Accuracy_list
        y2 = Loss_list

        total_test = 0
        correct_test = 0
        for i, (images, labels) in enumerate(test_loader):
            start_time = time.time()
            print('time_start', start_time)
            num_images = images.size(0)
            print('num_images', num_images)
            images = Variable(images.to(device))
            labels = Variable(labels.to(device))
            print("GroundTruth", labels)
            if args.model_name == 'googlenet':
                out = Net(images)[0]
                out = out[0]
            else:
                out = Net(images)
            _, predicted = torch.max(out.data, 1)
            print("predicted", predicted)
            correct_test += (predicted == labels).sum().item()
            total_test += labels.size(0)
            print('time_usage', (time.time() - start_time) / args.batch_size)
        print('total_test', total_test)
        print('correct_test', correct_test)
        print('accuracy={}'.format(100.0 * correct_test / total_test))

        plt.subplot(2, 1, 1)
        plt.plot(x1, y1, 'o-')
        plt.title('Train accuracy vs. epochs')
        plt.ylabel('Train accuracy')
        plt.subplot(2, 1, 2)
        plt.plot(x2, y2, '.-')
        plt.xlabel('Train loss vs. epochs')
        plt.ylabel('Train loss')
        # plt.savefig("accuracy_epoch" + str(EPOCH) + ".png")
        plt.savefig(args.output_dir + '/' + 'accuracy_epoch' + str(EPOCH) +
                    '.png')
        plt.show()
        # save the trained weights into the output directory
        torch.save(Net.state_dict(),
                   args.output_dir + '/' + args.model_name + '.pth')
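
For reference, the usual round trip for weights saved this way is state_dict() out and load_state_dict() back in. A tiny self-contained sketch (the file name and toy network are arbitrary):

import torch
import torch.nn as nn

net = nn.Sequential(nn.Linear(4, 8), nn.ReLU(), nn.Linear(8, 2))
torch.save(net.state_dict(), 'demo_model.pth')      # save only the weights

net2 = nn.Sequential(nn.Linear(4, 8), nn.ReLU(), nn.Linear(8, 2))
net2.load_state_dict(torch.load('demo_model.pth'))  # architecture must match
net2.eval()

x = torch.zeros(1, 4)
print(torch.allclose(net(x), net2(x)))  # -> True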