Example #1
    def __init__(self, opt):
        super(Trainer, self).__init__(opt)

        if self.isTrain and not opt.continue_train:
            self.model = resnet50(pretrained=True)
            self.model.fc = nn.Linear(2048, 1)
            torch.nn.init.normal_(self.model.fc.weight.data, 0.0,
                                  opt.init_gain)

        if not self.isTrain or opt.continue_train:
            self.model = resnet50(num_classes=1)

        if self.isTrain:
            self.loss_fn = nn.BCEWithLogitsLoss()
            # initialize optimizers
            if opt.optim == 'adam':
                self.optimizer = torch.optim.Adam(self.model.parameters(),
                                                  lr=opt.lr,
                                                  betas=(opt.beta1, 0.999))
            elif opt.optim == 'sgd':
                self.optimizer = torch.optim.SGD(self.model.parameters(),
                                                 lr=opt.lr,
                                                 momentum=0.0,
                                                 weight_decay=0)
            else:
                raise ValueError("optim should be [adam, sgd]")

        if not self.isTrain or opt.continue_train:
            self.load_networks(opt.epoch)
        self.model.to(opt.gpu_ids[0])
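
With a single-logit `fc` and `BCEWithLogitsLoss`, targets must be float tensors with the same shape as the output. A hedged sketch of the training step this setup implies (the `images`/`labels` names are assumptions, not part of the snippet):

# Hedged sketch of one optimization step for the binary real/fake head above.
logits = self.model(images)                        # (B, 1)
loss = self.loss_fn(logits, labels.float().unsqueeze(1))
self.optimizer.zero_grad()
loss.backward()
self.optimizer.step()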
Example #2
 def __init__(self, proposalN, num_classes, channels):
     # a subclass of nn.Module must call the parent constructor inside its own __init__
     super(MainNet, self).__init__()
     self.num_classes = num_classes
     self.proposalN = proposalN
     self.pretrained_model_bp = resnet.resnet50(pretrained=True,
                                                pth_path=pretrain_path)
     self.pretrained_model = resnet.resnet50(pretrained=True,
                                             pth_path=pretrain_path)
     self.rawcls_net = nn.Linear(channels, num_classes)
     self.localcls_net = nn.Linear(channels, 1)
     self.APPM = APPM()
Example #3
    def load_pretrain(num_classes, device):
        model_pre = resnet50(num_classes=1000, pretrained=True)  # ImageNet pretrained, num_classes=1000
        if num_classes == 1000:
            return model_pre.to(device)

        else:
            model = resnet50(num_classes=num_classes, pretrained=False)
            params_pre = model_pre.state_dict().copy()
            params = model.state_dict()
            for i in params_pre:
                if not i.startswith('fc'):
                    params[i] = params_pre[i]
            model.load_state_dict(params)
            return model.to(device)
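
The loop above copies every pretrained parameter except the `fc.*` entries, so the new classifier head keeps its fresh initialization. A minimal usage sketch, treating `load_pretrain` as a plain in-scope function (the 14-class value is only illustrative):

import torch

# Hedged usage sketch: assumes load_pretrain is in scope; 14 is an example value.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = load_pretrain(num_classes=14, device=device)  # pretrained backbone, fresh fc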
Example #4
def main(opt, myargs):
  model = resnet50(num_classes=1)
  state_dict = torch.load(opt.model_path, map_location='cpu')
  model.load_state_dict(state_dict['model'])
  if(not opt.use_cpu):
    model.cuda()
  model.eval()

  # Transform
  trans_init = []
  if(opt.crop is not None):
    trans_init = [transforms.CenterCrop(opt.crop),]
    print('Cropping to [%i]'%opt.crop)
  else:
    print('Not cropping')
  trans = transforms.Compose(trans_init + [
      transforms.ToTensor(),
      transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
  ])

  img = trans(Image.open(opt.file).convert('RGB'))

  with torch.no_grad():
      in_tens = img.unsqueeze(0)
      if(not opt.use_cpu):
        in_tens = in_tens.cuda()
      prob = model(in_tens).sigmoid().item()

  print('probability of being synthetic: {:.2f}%'.format(prob * 100))
  pass
Example #5
    def __init__(self, encoded_image_size=14):
        super(Encoder, self).__init__()
        self.enc_image_size = encoded_image_size

        self.resnet = resnet50(pretrained=True)  # pretrained ImageNet ResNet-50
        # Remove linear and pool layers (since we're not doing classification)
        self.adaptive_pool = nn.AdaptiveAvgPool2d((encoded_image_size, encoded_image_size))
Example #6
def main(opt, myargs):
  # Running tests
  # opt = TestOptions().parse(print_options=False)
  model_name = os.path.basename(opt.model_path).replace('.pth', '')
  rows = [["{} model testing on...".format(model_name)],
          ['testset', 'accuracy', 'avg precision']]

  print("{} model testing on...".format(model_name))
  for v_id, val_dict in enumerate(opt.vals):
    val = list(val_dict.keys())[0]
    print(f'Testset: {val}')
    opt.dataroot = '{}/{}'.format(opt.datadir, val)
    opt.classes = os.listdir(opt.dataroot) if val_dict[val] else ['']
    opt.no_resize = True  # testing without resizing by default

    model = resnet50(num_classes=1)
    state_dict = torch.load(opt.model_path, map_location='cpu')
    model.load_state_dict(state_dict['model'])
    model.cuda()
    model.eval()

    acc, ap, _, _, _, _ = validate(model, opt, stdout=myargs.stdout)
    rows.append([val, acc, ap])
    print("({}) acc: {}; ap: {}".format(val, acc, ap))

  csv_name = myargs.args.outdir + '/{}.csv'.format(model_name)
  with open(csv_name, 'w') as f:
    csv_writer = csv.writer(f, delimiter=',')
    csv_writer.writerows(rows)
Example #7
    def __init__(self, criterion, args):
        super(Adversary, self).__init__()

        self.net = resnet.resnet50(pretrained=args.pretrained, args=args)
        self.net.fc = torch.nn.Linear(2048, args.num_classes)
        self.criterion = criterion
        self.method = getattr(self, args.method)
Example #8
def create_model(ema=False):
    net = resnet.resnet50(pretrained=True)
    net.fc = torch.nn.Linear(net.fc.in_features, args.num_class)
    net = net.cuda()

    if ema:
        for param in net.parameters():
            param.detach_()
    return net
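
`param.detach_()` cuts the EMA copy out of the autograd graph, so optimizer steps never touch it; its weights are meant to change only through an explicit moving average of the student's parameters. A hedged sketch of that update (the decay value and the `zip` pairing are assumptions about typical usage):

import torch

# Hedged sketch of the usual EMA step for a teacher built with create_model(ema=True).
def update_ema(student, teacher, decay=0.999):
    with torch.no_grad():
        for p_s, p_t in zip(student.parameters(), teacher.parameters()):
            p_t.mul_(decay).add_(p_s, alpha=1 - decay)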
Example #9
def main():

    global args, best_prec1
    args = parser.parse_args()
    print(args)
    best_acc = 0
    # Create dataloader
    print('====> Creating dataloader...')

    data_dir = args.data
    train_list = args.trainlist
    dataset_name = args.dataset

    train_loader = get_dataset(dataset_name, data_dir, train_list)
    # load network
    if args.backbone == 'resnet50':
        model = resnet50(num_classes=args.num_classes)
    elif args.backbone == 'resnet101':
        model = resnet101(num_classes=args.num_classes)
    else:
        raise ValueError('unsupported backbone: {}'.format(args.backbone))

    if args.weights != '':
        try:
            ckpt = torch.load(args.weights)
            model.module.load_state_dict(ckpt['state_dict'])
            print('!!!load weights success !!! path is ', args.weights)
        except Exception as e:
            model_init(args.weights, model)
    model = torch.nn.DataParallel(model)
    model.cuda()
    mkdir_if_missing(args.save_dir)
    optimizer = torch.optim.Adam(filter(lambda p: p.requires_grad,
                                        model.parameters()),
                                 lr=10e-3)
    criterion = nn.CrossEntropyLoss().cuda()

    cudnn.benchmark = True

    for epoch in range(args.start_epoch, args.epochs + 1):
        # try not adjust learning rate
        #adjust_lr(optimizer, epoch)
        train(train_loader, model, criterion, optimizer, epoch)

        if epoch % args.val_step == 0:
            save_checkpoint(model, epoch, optimizer)
        '''
        if epoch% args.val_step == 0:
            acc = validate(test_loader, model, criterion)
            is_best = acc > best_acc
            best_acc = max(acc, best_acc)
            save_checkpoint({
                    'state_dict': model.module.state_dict(),
                    'epoch': epoch,
                }, is_best=is_best,train_batch=60000, save_dir=args.save_dir, filename='checkpoint_ep' + str(epoch) + '.pth.tar')
        '''

    return
Example #10
 def __init__(self, proposalN, num_classes, channels):
     # a subclass of nn.Module must call the parent constructor inside its own __init__
     super(MainNetMultitask, self).__init__()
     self.num_classes_1 = num_classes[0]
     self.num_classes_2 = num_classes[1]
     self.proposalN = proposalN
     self.pretrained_model = resnet.resnet50(pretrained=True,
                                             pth_path=pretrain_path)
     self.rawcls_net_1 = nn.Linear(channels, num_classes[0])
     self.rawcls_net_2 = nn.Linear(channels, num_classes[1])
     self.APPM = APPM()
Example #11
def load_classifier(model_path, gpu_id):
    if torch.cuda.is_available() and gpu_id != -1:
        device = 'cuda:{}'.format(gpu_id)
    else:
        device = 'cpu'
    model = resnet50(num_classes=1)
    state_dict = torch.load(model_path, map_location='cpu')
    model.load_state_dict(state_dict['model'])
    model.to(device)
    model.device = device
    model.eval()
    return model
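
Since `load_classifier` stores the chosen device on the model, callers can move inputs without re-deriving it. A hedged usage sketch (the checkpoint path and image file are placeholders):

import torch
from PIL import Image
from torchvision import transforms

# Hedged usage sketch: the path and image file are placeholders, not from the snippet.
model = load_classifier('weights/blur_jpg_prob0.5.pth', gpu_id=0)
img = transforms.ToTensor()(Image.open('input.png').convert('RGB'))
with torch.no_grad():
    prob = model(img.unsqueeze(0).to(model.device)).sigmoid().item()  # single-logit detector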
Example #12
    def __init__(self):
        args = ParserArgs().args
        cuda_visible(args.gpu)

        model = resnet50(in_channels=1, num_classes=2)
        model = nn.DataParallel(model).cuda()
        optimizer = torch.optim.Adam(model.parameters(),
                                     lr=args.lr,
                                     weight_decay=args.weight_decay)

        # Optionally resume from a checkpoint
        if args.resume:
            ckpt_root = os.path.join('/root/workspace', args.project,
                                     'checkpoints')
            ckpt_path = os.path.join(ckpt_root, args.resume)
            if os.path.isfile(ckpt_path):
                print("=> loading checkpoint '{}'".format(args.resume))
            #     checkpoint = torch.load(ckpt_path)
            #     args.start_epoch = checkpoint['epoch']
            #     self.val_best_iou = checkpoint['best_iou']
            #     model.load_state_dict(checkpoint['state_dict'])
            #     optimizer.load_state_dict(checkpoint['optimizer'])
            #     print("=> loaded checkpoint '{}' (epoch {})"
            #           .format(args.resume, checkpoint['epoch']))
            else:
                print("=> no checkpoint found at '{}'".format(args.resume))

        cudnn.benchmark = True

        self.vis = Visualizer(env='{}'.format(args.version), port=args.port)

        self.train_loader = ultraLoader(root=args.dataroot,
                                        batch=args.batch,
                                        version='train').data_load()
        self.val_loader = ultraLoader(root=args.dataroot,
                                      batch=args.batch,
                                      version='validation').data_load()
        self.test_loader = ultraLoader(root=args.dataroot,
                                       batch=args.batch,
                                       version='test_ours').data_load()
        self.test_loader_bigan = ultraLoader(root=args.dataroot,
                                             batch=args.batch,
                                             version='bigan').data_load()
        self.test_loader_cyclegan = ultraLoader(
            root=args.dataroot, batch=args.batch,
            version='cyclegan').data_load()

        print_args(args)
        self.args = args
        self.model = model
        self.optimizer = optimizer
        self.criterion = nn.CrossEntropyLoss().cuda()
Example #13
    def __init__(self, encoded_image_size=14):
        super(Encoder, self).__init__()
        self.enc_image_size = encoded_image_size

        # resnet = torchvision.models.resnet101(pretrained=False)  # pretrained ImageNet ResNet-101
        self.resnet = resnet50(
            pretrained=True)  # pretrained ImageNet ResNet-50
        # Remove linear and pool layers (since we're not doing classification)
        # modules = list(resnet.children())[:-2]
        # self.resnet = nn.Sequential(*modules)

        # Resize image to fixed size to allow input images of variable size
        self.adaptive_pool = nn.AdaptiveAvgPool2d(
            (encoded_image_size, encoded_image_size))
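
Note that with the truncation lines still commented out, `self.resnet` ends in its classification head and returns logits, not a feature map. A hedged sketch of the forward pass this encoder implies, assuming the commented-out `nn.Sequential(*modules)` truncation is enabled so the backbone yields (B, 2048, H/32, W/32) features:

    # Hedged sketch: assumes self.resnet was truncated to the convolutional backbone.
    def forward(self, images):
        out = self.resnet(images)        # (B, 2048, H/32, W/32)
        out = self.adaptive_pool(out)    # (B, 2048, enc_size, enc_size)
        return out.permute(0, 2, 3, 1)   # channels-last, as attention decoders typically expect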
Example #14
 def __init__(self, net_type, image_size=32, args=None):
     super(StrongDisc, self).__init__()
     self.net_type = net_type
     if net_type == 'inception_v3':
         self.net = inception_v3.inception_v3(pretrained=True,
                                              image_size=image_size)
     elif net_type == 'resnet18':
         self.net = resnet.resnet18(pretrained=True)
     elif net_type == 'resnet34':
         self.net = resnet.resnet34(pretrained=True)
     elif net_type == 'resnet50':
         self.net = resnet.resnet50(pretrained=True)
     elif net_type == 'resnet101':
         self.net = resnet.resnet101(pretrained=True)
     elif net_type == 'darts':
         self.net = darts.AugmentCNNOneOutput(model_path=args.darts_model)
     else:
         raise ValueError('unknown net_type: {}'.format(net_type))
Example #15
 def __init__(self,
              backbone_args,
              neck_args,
              bbox_head_args,
              train_cfg=None,
              test_cfg=None,
              pretrained=None):
     # super(FCOS, self).__init__(backbone, neck, bbox_head, train_cfg,
     #                            test_cfg, pretrained)
     super(FCOS, self).__init__()
     self.backbone = resnet50(**backbone_args)
     if neck_args is not None:
         self.neck = FPN(**neck_args)
     # bbox_head.update(train_cfg=train_cfg)
     # bbox_head.update(test_cfg=test_cfg)
     self.bbox_head = FCOSHead(**bbox_head_args)
     self.train_cfg = train_cfg
     self.test_cfg = test_cfg
     self.init_weights(pretrained=pretrained)
Example #16
def main():
    global args
    args = parser.parse_args()
    print(args)
    # Create dataloader
    print('====> Creating dataloader...')

    query_dir = args.querypath
    query_list = args.querylist
    gallery_dir = args.gallerypath
    gallery_list = args.gallerylist
    dataset_name = args.dataset

    query_loader, gallery_loader = get_dataset(dataset_name, query_dir,
                                               query_list, gallery_dir,
                                               gallery_list)
    # load network
    if args.backbone == 'resnet50':
        model = resnet50(num_classes=args.num_classes)
    elif args.backbone == 'resnet101':
        model = resnet101(num_classes=args.num_classes)
    else:
        raise ValueError('unsupported backbone: {}'.format(args.backbone))

    print(args.weights)

    if args.weights != '':
        try:
            model = torch.nn.DataParallel(model)
            ckpt = torch.load(args.weights)
            model.load_state_dict(ckpt['state_dict'])
            print('!!!load weights success !!! path is ', args.weights)
        except Exception as e:
            print('!!!load weights failed !!! path is ', args.weights)
            return
    else:
        print('!!!Load Weights PATH ERROR!!!')
        return
    model.cuda()
    mkdir_if_missing(args.save_dir)

    cudnn.benchmark = True
    evaluate(query_loader, gallery_loader, model)

    return
Example #17
def main(args):
    np.random.seed(123)

    log_out_dir = 'logs'
    os.makedirs(log_out_dir, exist_ok=True)
    time_stamp = time.strftime("%Y%m%d-%H%M%S")

    log_dir = os.path.join(log_out_dir, 'log-' + time_stamp + '.txt')
    logging.basicConfig(level=logging.INFO,
                        format='%(asctime)s %(message)s',
                        filename=log_dir,
                        filemode='w')

    print("Using GPUs:", args.gpus)
    os.environ["CUDA_VISIBLE_DEVICES"] = args.gpus
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    print(device)

    # Parameters Setting
    batch_size: int = 32
    num_workers: int = 4
    lr: float = 1e-3
    current_delta: float = 0.7
    flip_threshold = np.ones(args.nepochs) * 0.5
    initial_threshold = np.array([0.8, 0.8, 0.7, 0.6])
    flip_threshold[:len(initial_threshold)] = initial_threshold[:]

    # data augmentation
    transform_train = transforms.Compose([
        transforms.CenterCrop(224),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
    ])

    transform_test = transforms.Compose([
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
    ])

    data_root = args.data_root
    trainset = Clothing1M(data_root=data_root,
                          split='train',
                          transform=transform_train)
    train_loader = torch.utils.data.DataLoader(trainset,
                                               batch_size=batch_size,
                                               shuffle=True,
                                               num_workers=num_workers,
                                               drop_last=True)

    valset = Clothing1M(data_root=data_root,
                        split='val',
                        transform=transform_test)
    val_loader = torch.utils.data.DataLoader(valset,
                                             batch_size=batch_size * 4,
                                             shuffle=False,
                                             num_workers=num_workers)

    testset = Clothing1M(data_root=data_root,
                         split='test',
                         transform=transform_test)
    test_loader = torch.utils.data.DataLoader(testset,
                                              batch_size=batch_size * 4,
                                              shuffle=False,
                                              num_workers=num_workers)

    num_class = 14

    f = resnet50(num_classes=num_class, pretrained=True)
    f = nn.DataParallel(f)
    f.to(device)

    print("\n")
    print("============= Parameter Setting ================")
    print("Using Clothing1M dataset")
    print("Training Epoch : {} | Batch Size : {} | Learning Rate : {} ".format(
        args.nepochs, batch_size, lr))
    print("================================================")
    print("\n")

    print("============= Start Training =============")
    print("Start Label Correction at Epoch : {}".format(args.warm_up))

    criterion = torch.nn.CrossEntropyLoss()
    optimizer = torch.optim.SGD(f.parameters(),
                                lr=lr,
                                momentum=0.9,
                                nesterov=True,
                                weight_decay=5e-4)
    scheduler = MultiStepLR(optimizer, milestones=[6, 11], gamma=0.5)
    f_record = torch.zeros([args.rollWindow, len(trainset), num_class])

    test_acc = None
    best_val_acc = 0
    best_val_acc_epoch = 0
    best_test_acc = 0
    best_test_acc_epoch = 0
    best_weight = None

    for epoch in range(args.nepochs):
        train_loss = 0
        train_correct = 0
        train_total = 0

        f.train()
        for iteration, (features, labels, indices) in enumerate(
                tqdm(train_loader, ascii=True, ncols=50)):
            if features.shape[0] == 1:
                continue

            features, labels = features.to(device), labels.to(device)
            optimizer.zero_grad()
            outputs = f(features)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()

            train_loss += loss.item()
            train_total += features.size(0)
            _, predicted = outputs.max(1)
            train_correct += predicted.eq(labels).sum().item()

            f_record[epoch % args.rollWindow,
                     indices] = F.softmax(outputs.detach().cpu(), dim=1)

            # ----------------------------------------------------------------------
            # Evaluation if necessary
            if iteration % args.eval_freq == 0:
                print("\n>> Validation <<")
                f.eval()
                test_loss = 0
                test_correct = 0
                test_total = 0

                for _, (features, labels, indices) in enumerate(val_loader):
                    if features.shape[0] == 1:
                        continue

                    features, labels = features.to(device), labels.to(device)
                    outputs = f(features)
                    loss = criterion(outputs, labels)

                    test_loss += loss.item()
                    test_total += features.size(0)
                    _, predicted = outputs.max(1)
                    test_correct += predicted.eq(labels).sum().item()

                val_acc = test_correct / test_total * 100
                cprint(
                    ">> [Epoch: {}] Val Acc: {:3.3f}%\n".format(
                        epoch, val_acc), "blue")
                if best_val_acc < val_acc:
                    best_val_acc = val_acc
                    best_val_acc_epoch = epoch
                    best_weight = copy.deepcopy(f.state_dict())
                f.train()

        train_acc = train_correct / train_total * 100
        cprint(
            "Epoch [{}|{}] \t Train Acc {:.3f}%".format(
                epoch + 1, args.nepochs, train_acc), "yellow")
        cprint(
            "Epoch [{}|{}] \t Best Val Acc {:.3f}% \t Best Test Acc {:.3f}%".
            format(epoch + 1, args.nepochs, best_val_acc,
                   best_test_acc), "yellow")
        scheduler.step()

        if epoch >= args.warm_up:
            f_x = f_record.mean(0)
            y_tilde = trainset.targets

            y_corrected, current_delta = prob_correction(
                y_tilde,
                f_x,
                random_state=0,
                thd=0.1,
                current_delta=current_delta)

            logging.info('Current delta:\t{}\n'.format(current_delta))

            trainset.update_corrupted_label(y_corrected)

    # -- Final testing
    f.load_state_dict(best_weight)
    f.eval()

    test_loss = 0
    test_correct = 0
    test_total = 0

    for _, (features, labels, indices) in enumerate(test_loader):
        if features.shape[0] == 1:
            continue

        features, labels = features.to(device), labels.to(device)
        outputs = f(features)
        loss = criterion(outputs, labels)

        test_loss += loss.item()
        test_total += features.size(0)
        _, predicted = outputs.max(1)
        test_correct += predicted.eq(labels).sum().item()

    test_acc = test_correct / test_total * 100
    cprint(">> Test Acc: {:3.3f}%\n".format(test_acc), "yellow")
    if best_test_acc < test_acc:
        best_test_acc = test_acc
        best_test_acc_epoch = epoch
    print(">> Best validation accuracy: {:3.3f}%, at epoch {}".format(
        best_val_acc, best_val_acc_epoch))
    print(">> Best testing accuracy: {:3.3f}%, at epoch {}".format(
        best_test_acc, best_test_acc_epoch))
Example #18
    def __init__(self,
                 layers=50,
                 bins=(1, 2, 3, 6),
                 dropout=0.1,
                 classes=2,
                 zoom_factor=8,
                 use_ppm=True,
                 criterion=nn.CrossEntropyLoss(ignore_index=255),
                 BatchNorm=nn.BatchNorm2d,
                 pretrained=True):
        super(PSPNet, self).__init__()
        assert layers in [18, 50, 101, 152]
        assert 2048 % len(bins) == 0
        assert classes > 1
        assert zoom_factor in [1, 2, 4, 8]
        self.zoom_factor = zoom_factor
        self.use_ppm = use_ppm
        self.criterion = criterion
        models.BatchNorm = BatchNorm

        if layers == 50:
            resnet = models.resnet50(pretrained=pretrained)
        elif layers == 101:
            resnet = models.resnet101(pretrained=pretrained)
        elif layers == 18:
            resnet = models.resnet18(pretrained=pretrained)
        else:
            resnet = models.resnet152(pretrained=pretrained)

        if layers == 18:
            self.layer0 = nn.Sequential(resnet.conv1, resnet.bn1, resnet.relu,
                                        resnet.maxpool)
            self.layer1, self.layer2, self.layer3, self.layer4 = resnet.layer1, resnet.layer2, resnet.layer3, resnet.layer4

            for n, m in self.layer3.named_modules():
                if 'conv1' in n:
                    # print('find conv1', m)
                    m.dilation, m.padding, m.stride = (2, 2), (2, 2), (1, 1)
                elif 'downsample.0' in n:
                    # print('find downsample.0', m)
                    m.stride = (1, 1)
            for n, m in self.layer4.named_modules():
                if 'conv1' in n:
                    m.dilation, m.padding, m.stride = (4, 4), (4, 4), (1, 1)
                elif 'downsample.0' in n:
                    m.stride = (1, 1)
            fea_dim = 512

        else:
            self.layer0 = nn.Sequential(resnet.conv1, resnet.bn1, resnet.relu,
                                        resnet.conv2, resnet.bn2, resnet.relu,
                                        resnet.conv3, resnet.bn3, resnet.relu,
                                        resnet.maxpool)
            self.layer1, self.layer2, self.layer3, self.layer4 = resnet.layer1, resnet.layer2, resnet.layer3, resnet.layer4

            for n, m in self.layer3.named_modules():
                if 'conv2' in n:
                    # print('find conv2',m)
                    m.dilation, m.padding, m.stride = (2, 2), (2, 2), (1, 1)
                elif 'downsample.0' in n:
                    # print('find downsample.0',m)
                    m.stride = (1, 1)
            for n, m in self.layer4.named_modules():
                if 'conv2' in n:
                    m.dilation, m.padding, m.stride = (4, 4), (4, 4), (1, 1)
                elif 'downsample.0' in n:
                    m.stride = (1, 1)
            fea_dim = 2048

        # print('======*********=============')

        if use_ppm:
            self.ppm = PPM(fea_dim, int(fea_dim / len(bins)), bins, BatchNorm)
            fea_dim *= 2
        self.cls = nn.Sequential(
            nn.Conv2d(fea_dim, 512, kernel_size=3, padding=1, bias=False),
            BatchNorm(512),
            nn.ReLU(inplace=True),
            nn.Dropout2d(p=dropout),
            # nn.Conv2d(512, classes, kernel_size=1)
            nn.Conv2d(512, classes, kernel_size=1)
            if classes > 2 else nn.Conv2d(512, 1, kernel_size=1))
        if self.training:
            self.aux = nn.Sequential(
                nn.Conv2d(1024, 256, kernel_size=3, padding=1, bias=False)
                if layers != 18 else nn.Conv2d(
                    256, 256, kernel_size=3, padding=1, bias=False),
                BatchNorm(256),
                nn.ReLU(inplace=True),
                nn.Dropout2d(p=dropout),
                # nn.Conv2d(256, classes, kernel_size=1)
                nn.Conv2d(256, classes, kernel_size=1)
                if classes > 2 else nn.Conv2d(256, 1, kernel_size=1))
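
The module surgery on layer3/layer4 swaps their stride-2 downsampling for dilation, holding the backbone at output stride 8 instead of 32, which is why `zoom_factor=8` suffices to restore full resolution. A hedged standalone check of that idea with torchvision's resnet50 (the local `models` module above may differ):

import torch
import torchvision

# Hedged output-stride check; weights are irrelevant for shapes.
r = torchvision.models.resnet50()
for n, m in r.layer3.named_modules():
    if 'conv2' in n:
        m.dilation, m.padding, m.stride = (2, 2), (2, 2), (1, 1)
    elif 'downsample.0' in n:
        m.stride = (1, 1)
for n, m in r.layer4.named_modules():
    if 'conv2' in n:
        m.dilation, m.padding, m.stride = (4, 4), (4, 4), (1, 1)
    elif 'downsample.0' in n:
        m.stride = (1, 1)
backbone = torch.nn.Sequential(r.conv1, r.bn1, r.relu, r.maxpool,
                               r.layer1, r.layer2, r.layer3, r.layer4)
print(backbone(torch.randn(1, 3, 224, 224)).shape)  # (1, 2048, 28, 28): 224/28 = stride 8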
Example #19
def main(args):
    random_seed = args.seed
    np.random.seed(random_seed)
    random.seed(random_seed)
    torch.manual_seed(random_seed)
    torch.cuda.manual_seed(random_seed)
    torch.cuda.manual_seed_all(random_seed)
    torch.backends.cudnn.deterministic = True  # need to set to True as well

    print('Random Seed {}\n'.format(random_seed))

    # -- training parameters
    num_epoch = args.epoch
    milestone = [10, 20]
    batch_size = args.batch
    num_workers = 4

    weight_decay = 1e-3
    gamma = 0.1
    current_delta = args.delta

    lr = args.lr
    start_epoch = 0

    # -- specify dataset
    # data augmentation
    transform_train = transforms.Compose([
        transforms.RandomCrop(224),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
    ])

    transform_test = transforms.Compose([
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
    ])

    data_root = args.data_root
    trainset = Food101N(data_path=data_root, split='train', transform=transform_train)
    trainloader = torch.utils.data.DataLoader(trainset, batch_size=batch_size, shuffle=True, num_workers=num_workers,
                                              worker_init_fn=_init_fn, drop_last=True)

    testset = Food101N(data_path=data_root, split='test', transform=transform_test)
    testloader = torch.utils.data.DataLoader(testset, batch_size=batch_size * 4, shuffle=False, num_workers=num_workers)

    num_class = 101

    print('train data size:', len(trainset))
    print('test data size:', len(testset))

    # -- create log file
    time_stamp = time.strftime("%Y%m%d-%H%M%S")
    file_name = 'Ours(' + time_stamp + ').txt'

    log_dir = 'food101_logs'
    os.makedirs(log_dir, exist_ok=True)
    file_name = os.path.join(log_dir, file_name)
    saver = open(file_name, "w")

    saver.write(args.__repr__() + "\n\n")
    saver.flush()

    # -- set network, optimizer, scheduler, etc
    net = resnet50(num_classes=num_class, pretrained=True)
    net = nn.DataParallel(net)

    optimizer = optim.SGD(net.parameters(), lr=lr, weight_decay=weight_decay)
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    net = net.to(device)

    exp_lr_scheduler = lr_scheduler.MultiStepLR(optimizer, milestones=milestone, gamma=gamma)
    criterion = torch.nn.CrossEntropyLoss()

    # -- misc
    iterations = 0
    f_record = torch.zeros([args.rollWindow, len(trainset), num_class])

    for epoch in range(start_epoch, num_epoch):
        train_correct = 0
        train_loss = 0
        train_total = 0

        net.train()

        for i, (images, labels, indices) in enumerate(trainloader):
            if images.size(0) == 1:  # when batch size equals 1, skip, due to batch normalization
                continue

            images, labels = images.to(device), labels.to(device)

            outputs = net(images)
            loss = criterion(outputs, labels)

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            train_loss += loss.item()
            train_total += images.size(0)
            _, predicted = outputs.max(1)
            train_correct += predicted.eq(labels).sum().item()

            f_record[epoch % args.rollWindow, indices] = F.softmax(outputs.detach().cpu(), dim=1)

            iterations += 1
            if iterations % 100 == 0:
                cur_train_acc = train_correct / train_total * 100.
                cur_train_loss = train_loss / train_total
                cprint('epoch: {}\titerations: {}\tcurrent train accuracy: {:.4f}\ttrain loss:{:.4f}'.format(
                    epoch, iterations, cur_train_acc, cur_train_loss), 'yellow')

                if iterations % 5000 == 0:
                    saver.write('epoch: {}\titerations: {}\ttrain accuracy: {}\ttrain loss: {}\n'.format(
                        epoch, iterations, cur_train_acc, cur_train_loss))
                    saver.flush()
            
            if iterations % args.eval_freq == 0:
                net.eval()
                test_total = 0
                test_correct = 0
                with torch.no_grad():
                    for i, (images, labels, _) in enumerate(testloader):
                        images, labels = images.to(device), labels.to(device)

                        outputs = net(images)

                        test_total += images.size(0)
                        _, predicted = outputs.max(1)
                        test_correct += predicted.eq(labels).sum().item()

                    test_acc = test_correct / test_total * 100.
                    
                cprint('>> Test accuracy: {:.4f}'.format(test_acc), 'cyan')
                saver.write('>> Test accuracy: {}\n'.format(test_acc))
                saver.flush()
                net.train()

        train_acc = train_correct / train_total * 100.

        cprint('epoch: {}'.format(epoch), 'yellow')
        cprint('train accuracy: {:.4f}\ntrain loss: {:.4f}'.format(train_acc, train_loss), 'yellow')
        saver.write('epoch: {}\ntrain accuracy: {}\ntrain loss: {}\n'.format(epoch, train_acc, train_loss))
        saver.flush()

        exp_lr_scheduler.step()

        if epoch >= args.warm_up:
            f_x = f_record.mean(0)
            y_tilde = trainset.targets

            y_corrected, current_delta = lrt_correction(y_tilde, f_x, current_delta=current_delta, delta_increment=0.05)

            saver.write('Current delta:\t{}\n'.format(current_delta))

            trainset.update_corrupted_label(y_corrected)

    saver.close()
Example #20
parser.add_argument('-m',
                    '--model_path',
                    type=str,
                    default='weights/blur_jpg_prob0.5.pth')
parser.add_argument('-c',
                    '--crop',
                    type=int,
                    default=None,
                    help='by default, do not crop. specify crop size')
parser.add_argument('--use_cpu',
                    action='store_true',
                    help='uses gpu by default, turn on to use cpu')

opt = parser.parse_args()

model = resnet50(num_classes=1)
state_dict = torch.load(opt.model_path, map_location='cpu')
model.load_state_dict(state_dict['model'])
if (not opt.use_cpu):
    model.cuda()
model.eval()

# Transform
trans_init = []
if (opt.crop is not None):
    trans_init = [
        transforms.CenterCrop(opt.crop),
    ]
    print('Cropping to [%i]' % opt.crop)
else:
    print('Not cropping')
Example #21
def main():
    # Settings
    parser = argparse.ArgumentParser(description='PyTorch Clothing1M')
    parser.add_argument('--batch_size',
                        type=int,
                        default=256,
                        help='input batch size for training (default: 256)')
    parser.add_argument('--test_batch_size',
                        type=int,
                        default=256,
                        help='input batch size for testing (default: 256)')
    parser.add_argument('--epochs',
                        type=int,
                        default=10,
                        help='number of epochs to train (default: 10)')
    parser.add_argument('--gpu_id',
                        type=int,
                        default=0,
                        help='index of gpu to use (default: 0)')
    parser.add_argument('--lr',
                        type=float,
                        default=0.001,
                        help='init learning rate (default: 0.001)')
    parser.add_argument('--seed',
                        type=int,
                        default=0,
                        help='random seed (default: 0)')
    parser.add_argument('--save',
                        action='store_true',
                        default=False,
                        help='For saving softmax_out_avg')
    parser.add_argument('--SEAL',
                        type=int,
                        default=0,
                        help='Phase of self-evolution')
    args = parser.parse_args()

    torch.manual_seed(args.seed)
    device = torch.device(
        'cuda:' + str(args.gpu_id) if torch.cuda.is_available() else 'cpu')

    # Datasets
    root = './data/Clothing1M'
    num_classes = 14
    kwargs = {
        'num_workers': 32,
        'pin_memory': True
    } if torch.cuda.is_available() else {}
    transform_train = transforms.Compose([
        transforms.Resize((256, 256)),
        transforms.RandomCrop(224),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406],
                             std=[0.229, 0.224, 0.225])
    ])
    transform_test = transforms.Compose([
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406],
                             std=[0.229, 0.224, 0.225])
    ])

    train_dataset = Clothing1M(root, mode='train', transform=transform_train)
    val_dataset = Clothing1M(root, mode='val', transform=transform_test)
    test_dataset = Clothing1M(root, mode='test', transform=transform_test)

    train_loader = torch.utils.data.DataLoader(train_dataset,
                                               batch_size=args.batch_size,
                                               shuffle=True,
                                               **kwargs)
    softmax_loader = torch.utils.data.DataLoader(
        train_dataset,
        batch_size=args.test_batch_size,
        shuffle=False,
        **kwargs)
    val_loader = torch.utils.data.DataLoader(val_dataset,
                                             batch_size=args.test_batch_size,
                                             shuffle=False,
                                             **kwargs)
    test_loader = torch.utils.data.DataLoader(test_dataset,
                                              batch_size=args.test_batch_size,
                                              shuffle=False,
                                              **kwargs)

    def learning_rate(lr_init, epoch):
        optim_factor = 0
        if (epoch > 5):
            optim_factor = 1
        return lr_init * math.pow(0.1, optim_factor)

    def load_pretrain(num_classes, device):
        model_pre = resnet50(
            num_classes=1000,
            pretrained=True)  # imagenet pretrained, numclasses=1000
        if num_classes == 1000:
            return model_pre.to(device)

        else:
            model = resnet50(num_classes=num_classes, pretrained=False)
            params_pre = model_pre.state_dict().copy()
            params = model.state_dict()
            for i in params_pre:
                if not i.startswith('fc'):
                    params[i] = params_pre[i]
            model.load_state_dict(params)
            return model.to(device)

    # results
    results_root = os.path.join('results', 'clothing')
    if not os.path.isdir(results_root):
        os.makedirs(results_root)
    """ Test model """
    if args.SEAL == -1:
        model = resnet50().to(device)
        model.load_state_dict(
            torch.load(os.path.join(results_root, 'seed0_clothing_normal.pt')))
        test(args, model, device, test_loader)
    """ Get softmax_out_avg - normal training on noisy labels """
    if args.SEAL == 0:
        print(
            'The DMI model is trained using the official pytorch implemention of L_DMI <https://github.com/Newbeeer/L_DMI>.\n'
        )
    """ Self Evolution - training on softmax_out_avg from DMI model """
    if args.SEAL == 1:
        # Loading softmax_out_avg of last phase
        softmax_root = os.path.join(results_root, 'softmax_out_dmi.npy')
        softmax_out_avg = np.load(softmax_root).reshape(
            [-1, len(train_dataset), num_classes])
        softmax_out_avg = softmax_out_avg[:5].mean(
            axis=0
        )  # we found that the DMI model may not have converged in the last 5 epochs
        print('softmax_out_avg loaded from', softmax_root, ', shape: ',
              softmax_out_avg.shape)

        # Dataset with soft targets
        train_dataset_soft = Clothing1M_soft(root,
                                             targets_soft=torch.Tensor(
                                                 softmax_out_avg.copy()),
                                             mode='train',
                                             transform=transform_train)
        train_loader_soft = torch.utils.data.DataLoader(
            train_dataset_soft,
            batch_size=args.batch_size,
            shuffle=True,
            **kwargs)

        # Building model
        model = load_pretrain(num_classes, device)
        model = torch.nn.DataParallel(model, device_ids=[0, 1, 2, 3])
        model.load_state_dict(
            torch.load(os.path.join(results_root, 'clothing_dmi.pt')))
        print('Initialize the model using DMI model.')

        # Training
        best_val_acc = 0
        save_path = os.path.join(
            results_root, 'seed' + str(args.seed) + '_clothing_dmi_SEAL1.pt')
        softmax_out = []
        for epoch in range(1, args.epochs + 1):
            optimizer = optim.SGD(model.parameters(),
                                  lr=learning_rate(args.lr, epoch),
                                  momentum=0.9,
                                  weight_decay=1e-3)
            train_soft(args, model, device, train_loader_soft, optimizer,
                       epoch)
            best_val_acc = val_test(args, model, device, val_loader,
                                    test_loader, best_val_acc, save_path)
            softmax_out.append(get_softmax_out(model, softmax_loader, device))

        if args.save:
            softmax_root = os.path.join(
                results_root,
                'seed' + str(args.seed) + '_softmax_out_dmi_SEAL1.npy')
            softmax_out = np.concatenate(softmax_out)
            np.save(softmax_root, softmax_out)
            print('new softmax_out saved to', softmax_root, ', shape: ',
                  softmax_out.shape)

    if args.SEAL >= 2:
        # Loading softmax_out_avg of last phase
        softmax_root = os.path.join(
            results_root, 'seed' + str(args.seed) + '_softmax_out_dmi_SEAL' +
            str(args.SEAL - 1) + '.npy')
        softmax_out_avg = np.load(softmax_root).reshape(
            [-1, len(train_dataset), num_classes])
        softmax_out_avg = softmax_out_avg.mean(axis=0)
        print('softmax_out_avg loaded from', softmax_root, ', shape: ',
              softmax_out_avg.shape)

        # Dataset with soft targets
        train_dataset_soft = Clothing1M_soft(root,
                                             targets_soft=torch.Tensor(
                                                 softmax_out_avg.copy()),
                                             mode='train',
                                             transform=transform_train)
        train_loader_soft = torch.utils.data.DataLoader(
            train_dataset_soft,
            batch_size=args.batch_size,
            shuffle=True,
            **kwargs)

        # Building model
        model = load_pretrain(num_classes, device)
        model = torch.nn.DataParallel(model, device_ids=[0, 1, 2, 3])
        model_path = os.path.join(
            results_root, 'seed' + str(args.seed) + '_clothing_dmi_SEAL' +
            str(args.SEAL - 1) + '.pt')
        model.load_state_dict(torch.load(model_path))
        print('Initialize the model using {}.'.format(model_path))

        # Training
        best_val_acc = 0
        save_path = os.path.join(
            results_root, 'seed' + str(args.seed) + '_clothing_dmi_SEAL' +
            str(args.SEAL) + '.pt')
        softmax_out = []
        for epoch in range(1, args.epochs + 1):
            optimizer = optim.SGD(model.parameters(),
                                  lr=learning_rate(args.lr, epoch),
                                  momentum=0.9,
                                  weight_decay=1e-3)
            train_soft(args, model, device, train_loader_soft, optimizer,
                       epoch)
            best_val_acc = val_test(args, model, device, val_loader,
                                    test_loader, best_val_acc, save_path)
            softmax_out.append(get_softmax_out(model, softmax_loader, device))

        if args.save:
            softmax_root = os.path.join(
                results_root, 'seed' + str(args.seed) +
                '_softmax_out_dmi_SEAL' + str(args.SEAL) + '.npy')
            softmax_out = np.concatenate(softmax_out)
            np.save(softmax_root, softmax_out)
            print('new softmax_out saved to', softmax_root, ', shape: ',
                  softmax_out.shape)
Example #22
def main():
    args = get_args()
    use_gpu = False
    if torch.cuda.is_available():
        use_gpu = True
        torch.cuda.set_device(args.local_rank)
        torch.distributed.init_process_group(backend='nccl',
                                             init_method='env://')

    if args.local_rank == 0:
        log_format = '[%(asctime)s] %(message)s'
        logging.basicConfig(stream=sys.stdout, level=logging.INFO,
            format=log_format, datefmt='%d %I:%M:%S')
        t = time.time()
        local_time = time.localtime(t)
        if not os.path.exists('{}'.format(args.save)):
            os.makedirs('{}'.format(args.save))
        fh = logging.FileHandler(os.path.join('{}/log.train-{}-{}-{}-{}-{}-{}'.format(args.save, \
                    local_time.tm_year, local_time.tm_mon, local_time.tm_mday, \
                    local_time.tm_hour, local_time.tm_min, local_time.tm_sec)))
        fh.setFormatter(logging.Formatter(log_format))
        logging.getLogger().addHandler(fh)
        logging.info(args)

    if not args.test_only:
        assert os.path.exists(args.train_dir)
        args.train_dataloader = get_train_dataloader(args.train_dir, \
             args.batch_size//args.gpu_num, args.total_epoch,args.local_rank)

    assert os.path.exists(args.val_dir)
    if args.local_rank == 0:
        args.val_dataloader = get_val_dataloader(args.val_dir)

    print('rank {:d}: load data successfully'.format(args.local_rank))

    model = resnet50()
    optimizer = sgd_optimizer(model, args.learning_rate, args.momentum, args.weight_decay)
    scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer,
                    [30, 60, 90, 120],  0.1)

    if args.checkpoint_dir is not None:
        state_dict = torch.load(args.checkpoint_dir, map_location='cpu')
        model.load_state_dict(state_dict['model'])
        if 'optimizer' in state_dict.keys():
            optimizer.load_state_dict(state_dict['optimizer'])
        if 'scheduler' in state_dict.keys():
            scheduler.load_state_dict(state_dict['scheduler'])
        if 'iteration' in state_dict.keys():
            start_epoch = state_dict['iteration']
    else:
        start_epoch = 0

    args.loss_function = LabelSmoothCrossEntropyLoss().cuda()
    device = torch.device("cuda")
    model.to(device)
    for name, param in model.named_parameters():
        if 'momentum_buffer' in optimizer.state[param]:
            optimizer.state[param]['momentum_buffer'] = optimizer.state[param]['momentum_buffer'].cuda() 
        
    model = torch.nn.parallel.DistributedDataParallel(model,  device_ids=[args.local_rank], \
        output_device=args.local_rank, broadcast_buffers=False)

    args.optimizer = optimizer
    args.scheduler = scheduler

    if not args.test_only:
        train(model, device, args, start_epoch=start_epoch+1)

    if args.local_rank == 0:
        validate(model, device, args)
Example #23
def main():
    # Settings
    parser = argparse.ArgumentParser(description='PyTorch Clothing1M')
    parser.add_argument('--batch_size',
                        type=int,
                        default=256,
                        help='input batch size for training')
    parser.add_argument('--epochs',
                        type=int,
                        default=10,
                        help='number of epochs to train')
    parser.add_argument('--lr',
                        type=float,
                        default=1e-3,
                        help='init learning rate')
    parser.add_argument('--save_model',
                        action='store_true',
                        default=False,
                        help='For Saving the current Model')
    parser.add_argument(
        '--use_noisy_val',
        action='store_true',
        default=False,
        help=
        'Using the noisy validation setting. By default, using the benchmark setting.'
    )
    parser.add_argument('--init_path',
                        type=str,
                        default=None,
                        help='Path of a pretrained model')
    parser.add_argument('--teacher_path',
                        type=str,
                        default=None,
                        help='Path of the teacher model')
    parser.add_argument('--soft_targets',
                        type=bool,
                        default=True,
                        help='Use soft targets')
    parser.add_argument('--n_gpu',
                        type=int,
                        default=2,
                        help='number of gpu to use')
    parser.add_argument('--test_batch_size',
                        type=int,
                        default=256,
                        help='input batch size for testing')
    parser.add_argument('--root',
                        type=str,
                        default='data/Clothing1M/',
                        help='root of dataset')
    parser.add_argument('--seed', type=int, default=0, help='random seed')
    args = parser.parse_args()

    if args.teacher_path is None:
        exp_name = 'clothing1m_batch{}_seed{}'.format(args.batch_size,
                                                      args.seed)
    else:
        teacher_name = args.teacher_path.replace('models/', '')
        teacher_name = teacher_name[:teacher_name.find('_')]
        if 'net1' in args.teacher_path:
            teacher_name = teacher_name + 'net1'
        elif 'net2' in args.teacher_path:
            teacher_name = teacher_name + 'net2'
        if args.soft_targets:
            exp_name = 'softstudent_of_{}_clothing1m_batch{}_seed{}'.format(
                teacher_name, args.batch_size, args.seed)
        else:
            exp_name = 'student_of_{}_clothing1m_batch{}_seed{}'.format(
                teacher_name, args.batch_size, args.seed)
        if args.init_path is None:
            args.init_path = args.teacher_path

    if args.use_noisy_val:
        exp_name = 'nv_' + exp_name
    logpath = '{}.txt'.format(exp_name)
    log(logpath, 'Settings: {}\n'.format(args))

    torch.manual_seed(args.seed)
    torch.cuda.manual_seed_all(args.seed)
    device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')

    # soft loss
    def soft_cross_entropy(output, target):
        output = F.log_softmax(output, dim=1)
        loss = -torch.mean(torch.sum(output * target, dim=1))
        return loss

    # Datasets
    root = args.root
    num_classes = 14
    kwargs = {
        'num_workers': 32,
        'pin_memory': True
    } if torch.cuda.is_available() else {}
    train_transform = transforms.Compose([
        transforms.Resize(256),
        transforms.RandomCrop(224),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        transforms.Normalize((0.6959, 0.6537, 0.6371),
                             (0.3113, 0.3192, 0.3214)),
    ])
    test_transform = transforms.Compose([
        transforms.Resize(256),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        transforms.Normalize((0.6959, 0.6537, 0.6371),
                             (0.3113, 0.3192, 0.3214)),
    ])

    train_dataset = Clothing1M(root,
                               mode='train',
                               transform=train_transform,
                               use_noisy_val=args.use_noisy_val)
    val_dataset = Clothing1M(root,
                             mode='val',
                             transform=test_transform,
                             use_noisy_val=args.use_noisy_val)
    test_dataset = Clothing1M(root,
                              mode='test',
                              transform=test_transform,
                              use_noisy_val=args.use_noisy_val)

    train_loader = torch.utils.data.DataLoader(train_dataset,
                                               batch_size=args.batch_size,
                                               shuffle=True,
                                               **kwargs)
    val_loader = torch.utils.data.DataLoader(val_dataset,
                                             batch_size=args.test_batch_size,
                                             shuffle=False,
                                             **kwargs)
    test_loader = torch.utils.data.DataLoader(test_dataset,
                                              batch_size=args.test_batch_size,
                                              shuffle=False,
                                              **kwargs)

    if args.teacher_path is not None:
        teacher_model = resnet50(num_classes=num_classes).to(device)
        teacher_model = torch.nn.DataParallel(teacher_model,
                                              device_ids=list(range(
                                                  args.n_gpu)))
        state_dict = torch.load(args.teacher_path)
        if not list(state_dict.keys())[0][:7] == 'module.':
            state_dict = dict(('module.' + key, value)
                              for (key, value) in state_dict.items())
        teacher_model.load_state_dict(state_dict)
        distill_dataset = Clothing1M(root,
                                     mode='train',
                                     transform=test_transform,
                                     use_noisy_val=args.use_noisy_val)
        if args.soft_targets:
            pred = get_pred(teacher_model,
                            device,
                            distill_dataset,
                            args.test_batch_size,
                            num_workers=32,
                            output_softmax=True)
            train_criterion = soft_cross_entropy
        else:
            pred = get_pred(teacher_model,
                            device,
                            distill_dataset,
                            args.test_batch_size,
                            num_workers=32)
            train_criterion = F.cross_entropy
        train_dataset.targets = pred
        log(logpath, 'Get label from teacher {}.\n'.format(args.teacher_path))
        del teacher_model
    else:
        train_criterion = F.cross_entropy

    # Building model
    def learning_rate(lr_init, epoch):
        optim_factor = 0
        if (epoch > 5):
            optim_factor = 1
        return lr_init * math.pow(0.1, optim_factor)

    model = resnet50(pretrained=True)
    model.fc = nn.Linear(2048, num_classes)
    model = torch.nn.DataParallel(model.to(device),
                                  device_ids=list(range(args.n_gpu)))
    if args.init_path is not None:
        state_dict = torch.load(args.init_path)
        if not list(state_dict.keys())[0][:7] == 'module.':
            state_dict = dict(('module.' + key, value)
                              for (key, value) in state_dict.items())
        model.load_state_dict(state_dict)
        _, test_acc = test(args,
                           model,
                           device,
                           test_loader,
                           criterion=F.cross_entropy)
        log(logpath,
            'Initialized testing accuracy: {:.2f}\n'.format(100 * test_acc))
    cudnn.benchmark = True  # Accelerate training by enabling the inbuilt cudnn auto-tuner to find the best algorithm to use for your hardware.
    optimizer = optim.SGD(model.parameters(),
                          lr=args.lr,
                          momentum=0.9,
                          weight_decay=1e-3)

    # Training
    save_every_epoch = True
    if save_every_epoch:
        vals = []
        directory = 'models/' + exp_name
        if not os.path.exists(directory):
            os.makedirs(directory)

    val_best, epoch_best, test_at_best = 0, 0, 0
    for epoch in range(1, args.epochs + 1):
        t0 = time.time()
        lr = learning_rate(args.lr, epoch)
        for param_group in optimizer.param_groups:
            param_group['lr'] = lr
        _, train_acc = train(args,
                             model,
                             device,
                             train_loader,
                             optimizer,
                             epoch,
                             criterion=train_criterion)
        _, val_acc = test(args,
                          model,
                          device,
                          val_loader,
                          criterion=F.cross_entropy)
        _, test_acc = test(args,
                           model,
                           device,
                           test_loader,
                           criterion=F.cross_entropy)
        if val_acc > val_best:
            val_best, test_at_best, epoch_best = val_acc, test_acc, epoch
            if args.save_model:
                torch.save(model.state_dict(), '{}_best.pth'.format(exp_name))
        if save_every_epoch:
            vals.append(val_acc)
            torch.save(model.state_dict(),
                       '{}/epoch{}.pth'.format(directory, epoch))

        log(
            logpath,
            'Epoch: {}/{}, Time: {:.1f}s. '.format(epoch, args.epochs,
                                                   time.time() - t0))
        log(
            logpath,
            'Train: {:.2f}%, Val: {:.2f}%, Test: {:.2f}%; Val_best: {:.2f}%, Test_at_best: {:.2f}%, Epoch_best: {}\n'
            .format(100 * train_acc, 100 * val_acc, 100 * test_acc,
                    100 * val_best, 100 * test_at_best, epoch_best))

    if save_every_epoch:
        np.save('{}/val.npy'.format(directory), vals)
Example #24
def main():
    global args
    args = parser.parse_args()

    # TODO: the model-arguments module should be easier to write and read
    if args.approach == 'lwf':
        approach = lwf
        assert (args.memory_size is None)
        assert (args.memory_mini_batch_size is None)
    elif args.approach == 'joint_train':
        approach = joint_train
        assert (args.memory_size is None)
        assert (args.memory_mini_batch_size is None)
    elif args.approach == 'fine_tuning':
        approach = fine_tuning
        assert (args.memory_size is None)
        assert (args.memory_mini_batch_size is None)
    elif args.approach == 'gem':
        approach = gem
        assert (args.memory_size is not None)
        assert (args.memory_mini_batch_size is None)
    else:
        raise ValueError(
            'approach should be [lwf, joint_train, fine_tuning, gem]')

    rank, world_size = dist_init('27777')

    if rank == 0:
        print('=' * 100)
        print('Arguments = ')
        for arg in vars(args):
            print('\t' + arg + ':', getattr(args, arg))
        print('=' * 100)
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed(args.seed)
    else:
        print('[CUDA unavailable]')
        sys.exit()

    # Generate Tasks
    args.batch_size = args.batch_size // world_size
    Tasks = generator.GetTasks(args.approach,
                               args.batch_size,
                               world_size,
                               memory_size=args.memory_size,
                               memory_mini_batch_size=args.memory_mini_batch_size)
    # Network
    net = network.resnet50(pretrained=True).cuda()
    net = DistModule(net)
    # Approach
    Appr = approach.Approach(net, args, Tasks)

    # Solve tasks incrementally
    for t in range(len(Tasks)):
        task = Tasks[t]

        if rank == 0:
            print('*' * 100)
            print()
            print('Task {:d}: {:d} classes ({:s})'.format(
                t, task['class_num'], task['description']))
            print()
            print('*' * 100)

        Appr.solve(t, Tasks)

        if rank == 0:
            print('*' * 100)
            print('Task {:d}: {:d} classes Finished.'.format(
                t, task['class_num']))
            print('*' * 100)
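
`main()` wraps the network in `DistModule`, so its checkpoints carry the `module.` prefix, while the training snippet above adds the prefix only when it is missing. A small helper that normalizes the prefix in both directions avoids repeating that check; a minimal sketch (the helper name is ours, not from the snippets):

def normalize_state_dict(state_dict, parallel):
    """Add or strip the 'module.' prefix so a checkpoint saved with or
    without DataParallel/DistModule loads into either kind of model."""
    has_prefix = next(iter(state_dict)).startswith('module.')
    if parallel and not has_prefix:
        return {'module.' + k: v for k, v in state_dict.items()}
    if not parallel and has_prefix:
        return {k[len('module.'):]: v for k, v in state_dict.items()}
    return state_dict

# Usage sketch:
# model.load_state_dict(normalize_state_dict(torch.load(path), parallel=True))
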
def define_model(model_type,
                 pretrained_path='',
                 neighbour_slice=args.neighbour_slice,
                 input_type=args.input_type,
                 output_type=args.output_type):
    # Note: the args.* defaults are evaluated once, when the function is defined.
    if input_type == 'diff_img':
        input_channel = neighbour_slice - 1
    else:
        input_channel = neighbour_slice

    if model_type == 'prevost':
        model_ft = generators.PrevostNet()
    elif model_type == 'resnext50':
        model_ft = resnext.resnet50(sample_size=2,
                                    sample_duration=16,
                                    cardinality=32)
        model_ft.conv1 = nn.Conv3d(in_channels=1,
                                   out_channels=64,
                                   kernel_size=(3, 7, 7),
                                   stride=(1, 2, 2),
                                   padding=(1, 3, 3),
                                   bias=False)
    elif model_type == 'resnext101':
        model_ft = resnext.resnet101(sample_size=2,
                                     sample_duration=16,
                                     cardinality=32)
        model_ft.conv1 = nn.Conv3d(in_channels=1,
                                   out_channels=64,
                                   kernel_size=(3, 7, 7),
                                   stride=(1, 2, 2),
                                   padding=(1, 3, 3),
                                   bias=False)
        # model_ft.conv1 = nn.Conv3d(neighbour_slice, 64, kernel_size=7, stride=(1, 2, 2),
        #                            padding=(3, 3, 3), bias=False)
    elif model_type == 'resnet152':
        model_ft = resnet.resnet152(pretrained=True)
        model_ft.conv1 = nn.Conv2d(input_channel,
                                   64,
                                   kernel_size=7,
                                   stride=2,
                                   padding=3,
                                   bias=False)
    elif model_type == 'resnet101':
        model_ft = resnet.resnet101(pretrained=True)
        model_ft.conv1 = nn.Conv2d(input_channel,
                                   64,
                                   kernel_size=7,
                                   stride=2,
                                   padding=3,
                                   bias=False)
    elif model_type == 'resnet50':
        model_ft = resnet.resnet50(pretrained=True)
        model_ft.conv1 = nn.Conv2d(input_channel,
                                   64,
                                   kernel_size=7,
                                   stride=2,
                                   padding=3,
                                   bias=False)
    elif model_type == 'resnet34':
        model_ft = resnet.resnet34(pretrained=False)
        model_ft.conv1 = nn.Conv2d(input_channel,
                                   64,
                                   kernel_size=7,
                                   stride=2,
                                   padding=3,
                                   bias=False)
    elif model_type == 'resnet18':
        model_ft = resnet.resnet18(pretrained=True)
        model_ft.conv1 = nn.Conv2d(input_channel,
                                   64,
                                   kernel_size=7,
                                   stride=2,
                                   padding=3,
                                   bias=False)
    elif model_type == 'mynet':
        model_ft = mynet.resnet50(sample_size=2,
                                  sample_duration=16,
                                  cardinality=32)
        model_ft.conv1 = nn.Conv3d(in_channels=1,
                                   out_channels=64,
                                   kernel_size=(3, 7, 7),
                                   stride=(1, 2, 2),
                                   padding=(0, 3, 3),
                                   bias=False)
    elif model_type == 'mynet2':
        model_ft = generators.My3DNet()
    elif model_type == 'p3d':
        model_ft = p3d.P3D63()
        model_ft.conv1_custom = nn.Conv3d(1,
                                          64,
                                          kernel_size=(1, 7, 7),
                                          stride=(1, 2, 2),
                                          padding=(0, 3, 3),
                                          bias=False)
    elif model_type == 'densenet121':
        model_ft = densenet.densenet121()
    else:
        print('model type <{}> is not supported, using PrevostNet instead'.
              format(model_type))
        model_ft = generators.PrevostNet()

    num_ftrs = model_ft.fc.in_features

    if model_type == 'mynet':
        num_ftrs = 384
    elif model_type == 'prevost':
        num_ftrs = 576

    if output_type in ('average_dof', 'sum_dof'):
        model_ft.fc = nn.Linear(num_ftrs, 6)
    else:
        model_ft.fc = nn.Linear(num_ftrs, (neighbour_slice - 1) * 6)

    # if args.training_mode == 'finetune':
    #     model_path = path.join(results_dir, args.model_filename)
    #     if path.isfile(model_path):
    #         print('Loading model from <{}>...'.format(model_path))
    #         model_ft.load_state_dict(torch.load(model_path))
    #         print('Done')
    #     else:
    #         print('<{}> not exists! Training from scratch...'.format(model_path))

    if pretrained_path:
        if path.isfile(pretrained_path):
            print('Loading model from <{}>...'.format(pretrained_path))
            model_ft.load_state_dict(
                torch.load(pretrained_path, map_location='cuda:0'))
            # model_ft.load_state_dict(torch.load(pretrained_path))
            print('Done')
        else:
            print('<{}> does not exist! Training from scratch...'.format(
                pretrained_path))
    else:
        print('Training this model from scratch!')

    model_ft = model_ft.to(device)  # move the model to the configured device
    print('define_model device: {}'.format(device))
    return model_ft
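
`define_model` replaces `conv1` (different input channels) and `fc` (different output size), so a stock ImageNet checkpoint no longer matches those two layers. When loading such weights by hand, one option is to copy only the tensors whose names and shapes still agree. A minimal sketch under that assumption (the helper name is ours):

import torch

def load_matching_weights(model, checkpoint_path):
    """Copy only parameters whose name and shape match the current model,
    silently skipping replaced layers such as conv1 and fc."""
    pretrained = torch.load(checkpoint_path, map_location='cpu')
    own = model.state_dict()
    matched = {k: v for k, v in pretrained.items()
               if k in own and v.shape == own[k].shape}
    own.update(matched)
    model.load_state_dict(own)
    return model
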
import argparse

import torch.nn as nn
from torch.utils.data import DataLoader

import dataloaders  # project-local module providing MvtecLoader
import networks.resnet as resnet
from preprocess_SSL.SSL import model as ssl_model
from config import ROOT

parser = argparse.ArgumentParser()
parser.add_argument('--data', type=str, default="bottle")
parser.add_argument('--kmeans', type=int, default=128)
parser.add_argument('--type', type=str, default="all")
parser.add_argument('--index', type=int, default=30)
parser.add_argument('--image_size', type=int, default=256)
parser.add_argument('--patch_size', type=int, default=16)
parser.add_argument('--dim_reduction', type=str, default='PCA')
args = parser.parse_args()

scratch_model = nn.Sequential(
    resnet.resnet50(pretrained=False, num_classes=args.kmeans))
scratch_model = nn.DataParallel(scratch_model).cuda()

train_path = "{}/dataset/{}/train_resize/good/".format(ROOT, args.data)
train_dataset = dataloaders.MvtecLoader(train_path)
train_loader = DataLoader(train_dataset, batch_size=1, shuffle=False)
### Dataset covering all defect types
test_path = "{}/dataset/{}/test_resize/all/".format(ROOT, args.data)
test_dataset = dataloaders.MvtecLoader(test_path)
test_loader = DataLoader(test_dataset, batch_size=1, shuffle=False)

test_good_path = "{}/dataset/{}/test_resize/good/".format(ROOT, args.data)
test_good_dataset = dataloaders.MvtecLoader(test_good_path)
test_good_loader = DataLoader(test_good_dataset, batch_size=1, shuffle=False)

mask_path = "{}/dataset/{}/ground_truth_resize/all/".format(ROOT, args.data)

def create_model():
    model = resnet50(pretrained=True)
    model.fc = nn.Linear(2048, args.num_class)
    model = model.cuda()
    return model
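
`create_model` swaps a fresh `fc` head onto the pretrained backbone. When only that head should train, freezing everything else keeps the optimizer limited to the new parameters; a minimal sketch (10 classes stands in for `args.num_class`, and the torchvision `resnet50` is assumed; not part of the original snippet):

import torch
import torch.nn as nn
from torchvision.models import resnet50

model = resnet50(pretrained=True)
for param in model.parameters():
    param.requires_grad = False    # freeze the ImageNet backbone
model.fc = nn.Linear(2048, 10)     # new head; requires_grad=True by default

# Hand only the trainable parameters to the optimizer.
optimizer = torch.optim.SGD(
    [p for p in model.parameters() if p.requires_grad],
    lr=1e-3, momentum=0.9)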