parser.add_argument('--k', type=float, default=None, help='multiplier k applied to batch_size')
parser.add_argument('--v', type=float, default=0.1, help='data variance')


args = parser.parse_args()
args.use_cuda = args.ngpu > 0 and torch.cuda.is_available()

if args.manualSeed is None:
    args.manualSeed = random.randint(1, 10000)
random.seed(args.manualSeed)
torch.manual_seed(args.manualSeed)
if args.use_cuda:
    torch.cuda.manual_seed_all(args.manualSeed)
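# benchmark mode lets cuDNN auto-tune convolution algorithms for the observed
# input shapes; it speeds up fixed-size training but sacrifices determinism
# even though a manual seed is set above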
cudnn.benchmark = True

args.prefix = time_file_str()

def main():
    # Init logger
    if not os.path.isdir(args.save_path):
        os.makedirs(args.save_path)
    log = open(os.path.join(args.save_path, 'log_seed_{}.txt'.format(args.manualSeed)), 'w')
    print_log('save path : {}'.format(args.save_path), log)
    state = {k: v for k, v in args._get_kwargs()}
    print_log(state, log)
    print_log("Random Seed: {}".format(args.manualSeed), log)

    # Init dataset
    if not os.path.isdir(args.data_path):
        os.makedirs(args.data_path)
Example No. 2
def main():
    global args, best_prec1, best_prec5
    args = parser.parse_args()
    args.prefix = time_file_str()

    if not os.path.isdir(args.save_dir):
        os.makedirs(args.save_dir)

    log = open(
        os.path.join(args.save_dir, '{}.{}.log'.format(args.arch,
                                                       args.prefix)), 'w')
    log_top1err = open(
        os.path.join(args.save_dir,
                     '{}.{}.top1err-log'.format(args.arch, args.prefix)), 'w')
    log_top5err = open(
        os.path.join(args.save_dir,
                     '{}.{}.top5err-log'.format(args.arch, args.prefix)), 'w')

    if args.seed is not None:
        random.seed(args.seed)
        torch.manual_seed(args.seed)
        cudnn.deterministic = True
        warnings.warn('You have chosen to seed training. '
                      'This will turn on the CUDNN deterministic setting, '
                      'which can slow down your training considerably! '
                      'You may see unexpected behavior when restarting '
                      'from checkpoints.')

    if args.gpu is not None:
        warnings.warn('You have chosen a specific GPU. This will completely '
                      'disable data parallelism.')

    args.distributed = args.world_size > 1

    if args.distributed:
        dist.init_process_group(backend=args.dist_backend,
                                init_method=args.dist_url,
                                world_size=args.world_size)
    # create model
    model = model_dict[args.arch]
    print_log("[model] '{}'".format(args.arch), log)
    print_log("{}".format(model), log)
    print_log("[args parameter] : {}".format(args), log)

    if args.gpu is not None:
        model = model.cuda(args.gpu)
    elif args.distributed:
        model.cuda()
        model = torch.nn.parallel.DistributedDataParallel(model)
    else:
        if args.arch.startswith('alexnet') or args.arch.startswith('vgg'):
            model.features = torch.nn.DataParallel(model.features)
            model.cuda()
        else:
            model = torch.nn.DataParallel(model).cuda()

    # define loss function (criterion) and optimizer
    criterion = nn.CrossEntropyLoss().cuda(args.gpu)
    optimizer = torch.optim.SGD(model.parameters(),
                                args.lr,
                                momentum=args.momentum,
                                weight_decay=args.weight_decay)
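    # the base lr set here is adjusted per-epoch by adjust_learning_rate()
    # in the training loop below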

    # optionally resume from a checkpoint
    if args.resume:
        if os.path.isfile(args.resume):
            print("=> loading checkpoint '{}'".format(args.resume))
            checkpoint = torch.load(args.resume)
            args.start_epoch = checkpoint['epoch']
            best_prec1 = checkpoint['best_prec1']
            model.load_state_dict(checkpoint['state_dict'])
            optimizer.load_state_dict(checkpoint['optimizer'])
            print("=> loaded checkpoint '{}' (epoch {})".format(
                args.resume, checkpoint['epoch']))
        else:
            print("=> no checkpoint found at '{}'".format(args.resume))

    cudnn.benchmark = True
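    # standard ImageNet-style augmentation; the mean/std values below are the
    # usual ImageNet channel statistics expected by torchvision pretrained models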
    train_transform = transforms.Compose([
        transforms.RandomResizedCrop(size=224),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406],
                             std=[0.229, 0.224, 0.225])
    ])

    test_transform = transforms.Compose([
        transforms.Resize(256),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406],
                             std=[0.229, 0.224, 0.225])
    ])

    def image_loader_PIL(path):  # PIL loader shared by the train and test sets
        return Image.open(path).convert('RGB')
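    # convert('RGB') coerces grayscale/RGBA/palette files to 3 channels so the
    # 3-channel Normalize in the transforms above never fails on odd inputs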

    class DSet(Dataset):
        def __init__(self,
                     image_list='',
                     transform=None,
                     loader=image_loader_PIL,
                     data_path=''):
            imgs = []
            with open(image_list, 'r') as f:
                for line in f:
                    sample = line.strip().split()
                    imgs.append((sample[0], int(sample[1])))
            self.imgs = imgs
            self.transform = transform
            self.loader = loader
            self.data_path = data_path

        def __getitem__(self, index):
            image_name, label = self.imgs[index]
            image_name = self.data_path + image_name
            img = self.loader(image_name)
            img = self.transform(img)
            return img, label

        def __len__(self):
            return len(self.imgs)
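
    # each line of image_list is expected to be "<relative_path> <int_label>",
    # matching the split()/int() parsing in __init__ above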

    train_data = DSet(image_list='your_train_data_list.txt',
                      transform=train_transform,
                      loader=image_loader_PIL,
                      data_path=args.train_data_path)
    val_data = DSet(image_list='your_test_data_list.txt',
                    transform=test_transform,
                    loader=image_loader_PIL,
                    data_path=args.test_data_path)
    if args.distributed:
        train_sampler = torch.utils.data.distributed.DistributedSampler(
            train_data)
    else:
        train_sampler = None

    train_loader = DataLoader(train_data,
                              batch_size=args.batch_size,
                              shuffle=(train_sampler is None),
                              num_workers=args.workers,
                              pin_memory=True,
                              sampler=train_sampler)
    val_loader = DataLoader(val_data,
                            batch_size=10,
                            shuffle=False,
                            num_workers=args.workers,
                            pin_memory=True)

    if args.eval == 1:
        validate(val_loader, model, criterion, log, log_top1err, log_top5err)
        return

    for epoch in range(args.start_epoch, args.epochs):
        if args.distributed:
            train_sampler.set_epoch(epoch)
        adjust_learning_rate(optimizer, epoch)

        # train for one epoch
        train(train_loader, model, criterion, optimizer, epoch, log)

        # evaluate on validation set
        prec1, prec5 = validate(val_loader, model, criterion, log, log_top1err,
                                log_top5err)

        # remember best prec@1 and save checkpoint
        is_best = prec1 > best_prec1
        best_prec1 = max(prec1, best_prec1)
        best_prec5 = max(prec5, best_prec5)
        print('[==epoch==]', epoch, '[==best_top1==] ', 100 - best_prec1,
              '[==best_top5==] ', 100 - best_prec5)

        save_checkpoint(
            {
                'epoch': epoch + 1,
                'arch': args.arch,
                'state_dict': model.state_dict(),
                'best_prec1': best_prec1,
                'optimizer': optimizer.state_dict(),
            }, is_best, args)
Example No. 3
def main():
    parser = argparse.ArgumentParser(description='FAVAE anomaly detection')
    parser.add_argument('--obj', type=str, default='.')
    parser.add_argument('--data_type', type=str, default='mvtec')
    parser.add_argument('--data_path', type=str, default='')
    parser.add_argument('--epochs',
                        type=int,
                        default=100,
                        help='maximum training epochs')
    parser.add_argument('--batch_size', type=int, default=64)
    parser.add_argument('--validation_ratio', type=float, default=0.2)
    parser.add_argument('--grayscale',
                        action='store_true',
                        help='color or grayscale input image')
    parser.add_argument('--img_resize', type=int, default=128)
    parser.add_argument('--crop_size', type=int, default=128)
    parser.add_argument('--do_aug',
                        action='store_true',
                        help='whether to do data augmentation before training')
    parser.add_argument('--augment_num', type=int, default=10000)
    parser.add_argument('--p_rotate',
                        type=float,
                        default=0.3,
                        help='probability to do image rotation')
    parser.add_argument('--rotate_angle_vari',
                        type=float,
                        default=15.0,
                        help='rotate image between [-angle, +angle]')
    parser.add_argument('--p_rotate_crop',
                        type=float,
                        default=1.0,
                        help='probability to crop inner rotated image')
    parser.add_argument('--p_horizonal_flip',
                        type=float,
                        default=0.3,
                        help='probability to do horizontal flip')
    parser.add_argument('--p_vertical_flip',
                        type=float,
                        default=0.3,
                        help='probability to do vertical flip')
    parser.add_argument('--kld_weight', type=float, default=1.0)
    parser.add_argument('--lr',
                        type=float,
                        default=0.005,
                        help='learning rate of Adam')
    parser.add_argument('--weight_decay',
                        type=float,
                        default=0.00001,
                        help='weight decay of Adam')
    parser.add_argument('--seed', type=int, default=None, help='manual seed')
    args = parser.parse_args()

    args.p_crop = 1 if args.crop_size != args.img_resize else 0
    args.train_data_dir = args.data_path + '/' + args.obj + '/train/good'
    args.aug_dir = './train_patches/' + args.obj + '/train/good'

    args.input_channel = 1 if args.grayscale else 3

    if args.seed is None:
        args.seed = random.randint(1, 10000)
    # seed unconditionally so that a user-supplied --seed is honored as well
    random.seed(args.seed)
    torch.manual_seed(args.seed)
    if use_cuda:
        torch.cuda.manual_seed_all(args.seed)

    args.prefix = time_file_str()
    args.save_dir = './' + args.data_type + '/' + args.obj + '/vgg_feature' + '/seed_{}/'.format(
        args.seed)

    # data augmentation
    if not os.path.exists(args.aug_dir) and args.do_aug:
        os.makedirs(args.aug_dir)
        img_list = generate_image_list(args)
        augment_images(img_list, args)
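    # patches are generated only when aug_dir does not exist yet; delete the
    # directory to re-run augmentation with different settings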

    args.train_data_path = './train_patches'

    if not os.path.exists(args.save_dir):
        os.makedirs(args.save_dir)
    log = open(
        os.path.join(args.save_dir,
                     'model_training_log_{}.txt'.format(args.prefix)), 'w')
    state = {k: v for k, v in args._get_kwargs()}
    print_log(state, log)

    # load model and dataset
    model = VAE(input_channel=args.input_channel, z_dim=100).to(device)
    teacher = models.vgg16(pretrained=True).to(device)
    for param in teacher.parameters():
        param.requires_grad = False

    optimizer = optim.Adam(model.parameters(),
                           lr=args.lr,
                           weight_decay=args.weight_decay)

    img_size = args.crop_size  # patches are cropped to crop_size regardless of img_resize
    kwargs = {'num_workers': 4, 'pin_memory': True} if use_cuda else {}
    train_dataset = MVTecDataset(args.train_data_path,
                                 class_name=args.obj,
                                 is_train=True,
                                 resize=img_size)
    img_nums = len(train_dataset)
    valid_num = int(img_nums * args.validation_ratio)
    train_num = img_nums - valid_num
    train_data, val_data = torch.utils.data.random_split(
        train_dataset, [train_num, valid_num])
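    # random_split draws from torch's default RNG, so this train/val split is
    # reproducible only because torch.manual_seed() was called earlier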
    train_loader = torch.utils.data.DataLoader(train_data,
                                               batch_size=args.batch_size,
                                               shuffle=True,
                                               **kwargs)
    val_loader = torch.utils.data.DataLoader(val_data,
                                             batch_size=32,
                                             shuffle=False,
                                             **kwargs)

    test_dataset = MVTecDataset(args.data_path,
                                class_name=args.obj,
                                is_train=False,
                                resize=img_size)
    test_loader = torch.utils.data.DataLoader(test_dataset,
                                              batch_size=32,
                                              shuffle=True,
                                              **kwargs)
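    # note: shuffle=True on the test loader means the "fixed" debug batch
    # sampled below is a different random batch on every run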

    # fetch fixed data for debugging
    x_normal_fixed, _, _ = next(iter(val_loader))
    x_normal_fixed = x_normal_fixed.to(device)

    x_test_fixed, _, _ = next(iter(test_loader))
    x_test_fixed = x_test_fixed.to(device)

    # start training
    save_name = os.path.join(args.save_dir,
                             '{}_{}_model.pt'.format(args.obj, args.prefix))
    early_stop = EarlyStop(patience=20, save_name=save_name)
    start_time = time.time()
    epoch_time = AverageMeter()
    for epoch in range(1, args.epochs + 1):
        need_hour, need_mins, need_secs = convert_secs2time(
            epoch_time.avg * (args.epochs - epoch))
        need_time = '[Need: {:02d}:{:02d}:{:02d}]'.format(
            need_hour, need_mins, need_secs)
        print_log(
            ' {:3d}/{:3d} ----- [{:s}] {:s}'.format(epoch, args.epochs,
                                                    time_string(), need_time),
            log)
        train(args, model, teacher, epoch, train_loader, optimizer, log)
        val_loss = val(args, model, teacher, epoch, val_loader, log)

        if early_stop(val_loss, model, optimizer, log):
            break

        if epoch % 10 == 0:
            save_sample = os.path.join(args.save_dir,
                                       '{}val-images.jpg'.format(epoch))
            save_sample2 = os.path.join(args.save_dir,
                                        '{}test-images.jpg'.format(epoch))
            save_snapshot(x_normal_fixed, x_test_fixed, model, save_sample,
                          save_sample2, log)

        epoch_time.update(time.time() - start_time)
        start_time = time.time()
    log.close()
Example No. 4
def main():

  # Init logger
  args.save_path = os.path.join(args.save_path, 'seed-{:}'.format(args.manualSeed))
  if not os.path.isdir(args.save_path):
    os.makedirs(args.save_path)
  log = open(os.path.join(args.save_path, 'log-seed-{:}-{:}.txt'.format(args.manualSeed, time_file_str())), 'w')
  print_log('save path : {:}'.format(args.save_path), log)
  state = {k: v for k, v in args._get_kwargs()}
  print_log(state, log)
  print_log("Random Seed: {}".format(args.manualSeed), log)
  print_log("Python version : {}".format(sys.version.replace('\n', ' ')), log)
  print_log("Torch  version : {}".format(torch.__version__), log)
  print_log("CUDA   version : {}".format(torch.version.cuda), log)
  print_log("cuDNN  version : {}".format(cudnn.version()), log)
  print_log("Num of GPUs    : {}".format(torch.cuda.device_count()), log)
  print_log("Num of CPUs    : {}".format(multiprocessing.cpu_count()), log)

  config = load_config( args.config_path )
  genotype = Networks[ args.arch ]

  main_procedure(config, genotype, args.save_path, args.print_freq, log)
  log.close()