Example #1
    def __getitem__(self, index):
        if self.is_train:
            img, target = self.train_img[index], self.train_label[index]
            if len(img.shape) == 2:
                # grayscale input: replicate to three channels for RGB models
                img = np.stack([img] * 3, 2)
            img = Image.fromarray(img, mode="RGB")
            img = transforms.Resize((256, 256), Image.BILINEAR)(img)
            img = transforms.RandomCrop(INPUT_SIZE)(img)
            img = transforms.RandomHorizontalFlip()(img)
            img = transforms.ToTensor()(img)
            img = transforms.Normalize([0.485, 0.456, 0.406],
                                       [0.229, 0.224, 0.225])(img)

        else:
            img, target = self.test_img[index], self.test_label[index]
            if len(img.shape) == 2:
                # grayscale input: replicate to three channels for RGB models
                img = np.stack([img] * 3, 2)
            img = Image.fromarray(img, mode="RGB")
            img = transforms.Resize((256, 256), Image.BILINEAR)(img)
            img = transforms.CenterCrop(INPUT_SIZE)(img)
            img = transforms.ToTensor()(img)
            img = transforms.Normalize([0.485, 0.456, 0.406],
                                       [0.229, 0.224, 0.225])(img)

        return img, target
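A hedged usage sketch for the dataset above; the class name BirdDataset and its constructor signature are assumptions, since the example only shows __getitem__:

    from torch.utils.data import DataLoader

    train_set = BirdDataset(is_train=True)  # assumed constructor
    train_loader = DataLoader(train_set, batch_size=32, shuffle=True)
    for img, target in train_loader:
        pass  # img: (B, 3, INPUT_SIZE, INPUT_SIZE) normalized float tensor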
Example #2
    def __init__(self,
                 annFile: str,
                 imageDir: str,
                 targetHeight: int,
                 targetWidth: int,
                 numClass: int,
                 train: bool = True):
        self.annotations = {}
        self.table = {
            '不良-機械傷害': 0,  # defect: mechanical damage
            '不良-著色不佳': 1,  # defect: poor coloration
            '不良-炭疽病': 2,  # defect: anthracnose
            '不良-乳汁吸附': 3,  # defect: latex stain
            '不良-黑斑病': 4  # defect: black spot disease
        }
        self.imageDir = imageDir
        self.numClass = numClass

        with open(annFile, 'r', encoding='utf-8-sig') as f:
            for line in f.readlines():
                arr = line.rstrip().split(',')
                ans = []

                for idx in range(1, len(arr), 5):
                    tlx, tly, w, h, c = arr[idx:idx + 5]

                    if tlx:
                        tlx, tly, w, h = list(map(float, (tlx, tly, w, h)))
                        if c not in self.table:
                            self.table[c] = len(self.table)

                        cx = tlx + w / 2
                        cy = tly + h / 2
                        c = self.table[c]

                        ans.append(list(map(int, (cx, cy, w, h, c))))

                self.annotations[arr[0]] = ans

        self.names = list(self.annotations)

        with open('table.txt', 'w', encoding='utf-8') as f:  # utf-8 so the Chinese class names round-trip
            f.write(str(self.table))
            print(self.table)

        if train:
            self.transforms = T.Compose([
                T.RandomOrder([
                    T.RandomHorizontalFlip(),
                    T.RandomVerticalFlip(),
                    T.RandomSizeCrop(numClass)
                ]),
                T.Resize((targetHeight, targetWidth)),
                T.ColorJitter(brightness=.2, contrast=0, saturation=0, hue=0),
                T.Normalize()
            ])
        else:
            self.transforms = T.Compose(
                [T.Resize((targetHeight, targetWidth)),
                 T.Normalize()])
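A hedged construction sketch, assuming the CSV layout the parser above expects (filename, then repeating tlx,tly,w,h,label groups); the class name MangoDataset and all argument values are placeholders:

    dataset = MangoDataset(annFile='train.csv',
                           imageDir='images/',
                           targetHeight=512,
                           targetWidth=512,
                           numClass=5,
                           train=True)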
Example #3
def get_valloader_only(toy,
                       rsyncing,
                       batch_size=16,
                       num_workers=os.cpu_count() - 1,
                       notebook=False,
                       cat=False):
    """

    :param toy:
    :param rsyncing:
    :param batch_size:
    :param num_workers:
    :param notebook:
    :return:
    """
    num_workers = 0  # NOTE: hard-coded override of the num_workers argument above
    if rsyncing:
        print('Rsynced data! (prepare feat)', flush=True)
    else:
        print('Using symbolic links! (prepare feat)', flush=True)
    print('Getting path ready..', flush=True)
    _, anno_path_val, png_path = get_paths(rsyncing, toy, notebook)

    start_time = time.time()
    print('Creating Coco Dataset..', flush=True)

    if not cat:
        valset = u.dataset_coco(
            png_path,
            anno_path_val,
            transform=torchvision.transforms.Compose(
                [t.Normalize(), t.BboxCrop(targetsize=224)]),
            bbox_transform=torchvision.transforms.Compose([t.GetFiveBBs()]),
            for_feature=True,
            cat=cat)
    else:
        valset = u.dataset_coco(
            png_path,
            anno_path_val,
            transform=torchvision.transforms.Compose(
                [t.Normalize(), t.BboxCropMult(targetsize=224)]),
            bbox_transform=torchvision.transforms.Compose([t.GetBBsMult()]),
            for_feature=True,
            cat=cat)
    print('Validation set has', len(valset), 'images', flush=True)

    valloader = torch.utils.data.DataLoader(valset,
                                            batch_size=batch_size,
                                            sampler=SequentialSampler(valset),
                                            num_workers=num_workers,
                                            collate_fn=u.mammo_collate)
    print('Validation loader has', len(valloader), 'batches', flush=True)

    total_time = time.time() - start_time
    print('Creating Datasets took {:.0f} seconds.'.format(total_time),
          flush=True)
    return valloader
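A hedged call-site sketch for the factory above (argument values are placeholders):

    valloader = get_valloader_only(toy=False, rsyncing=False, batch_size=16)
    for batch in valloader:
        pass  # batches are assembled by u.mammo_collate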
Example #4
    def __init__(self, root: str, annotation: str, targetHeight: int,
                 targetWidth: int, numClass: int):
        self.root = root
        self.coco = COCO(annotation)
        self.ids = list(self.coco.imgs.keys())

        self.targetHeight = targetHeight
        self.targetWidth = targetWidth
        self.numClass = numClass

        self.transforms = T.Compose([
            T.RandomOrder(
                [T.RandomHorizontalFlip(),
                 T.RandomSizeCrop(numClass)]),
            T.Resize((targetHeight, targetWidth)),
            T.ColorJitter(brightness=.2, contrast=.1, saturation=.1, hue=0),
            T.Normalize()
        ])

        self.newIndex = {}
        classes = []
        for i, (k, v) in enumerate(self.coco.cats.items()):
            self.newIndex[k] = i
            classes.append(v['name'])

        with open('classes.txt', 'w') as f:
            f.write(str(classes))
Example #5
def load_data(args):

    normalize = t.Normalize(mean=[0.445, 0.287, 0.190],
                            std=[0.31, 0.225, 0.168])
    im_transform = t.Compose([t.ToTensor(), normalize])

    # Use the following code for co_transformations, e.g. random rotation or random flip:
    # co_transformer = cot.Compose([cot.RandomRotate(45)])

    dsetTrain = GIANA(args.imgdir,
                      args.gtdir,
                      input_size=(args.input_width, args.input_height),
                      train=True,
                      transform=im_transform,
                      co_transform=None,
                      target_transform=t.ToLabel())
    train_data_loader = data.DataLoader(dsetTrain,
                                        batch_size=args.batch_size,
                                        shuffle=True,
                                        num_workers=args.num_workers)

    dsetVal = GIANA(args.imgdir,
                    args.gtdir,
                    train=False,
                    transform=im_transform,
                    co_transform=None,
                    target_transform=t.ToLabel())
    val_data_loader = data.DataLoader(dsetVal,
                                      batch_size=args.batch_size,
                                      shuffle=False,
                                      num_workers=args.num_workers)
    return train_data_loader, val_data_loader
Example #6
    def forward(self, x):
        ts = transforms.Compose([
            T.Normalize(mean=[0.485, 0.456, 0.406],
                        std=[0.229, 0.224, 0.225],
                        channel=1)
        ])
        return self.feat_ext(ts(x))
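The Compose above is rebuilt on every call to forward; a minimal sketch of the usual alternative, building the pipeline once in __init__ (the constructor signature and nn.Module base are assumptions):

    def __init__(self, feat_ext):
        super().__init__()  # assumes this class subclasses nn.Module
        self.feat_ext = feat_ext
        # Build the normalization pipeline once, not per forward pass.
        self.ts = transforms.Compose([
            T.Normalize(mean=[0.485, 0.456, 0.406],
                        std=[0.229, 0.224, 0.225],
                        channel=1)
        ])

    def forward(self, x):
        return self.feat_ext(self.ts(x))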
Example #7
    def __init__(self,
                 root: str,
                 targetHeight: int,
                 targetWidth: int,
                 numClass: int,
                 train: bool = True):
        """
        :param root: should contain .jpg files and corresponding .txt files
        :param targetHeight: desired height for model input
        :param targetWidth: desired width for model input
        :param numClass: number of classes in the given dataset
        """
        self.cache = {}

        imagePaths = glob(os.path.join(root, '*.jpg'))
        for path in imagePaths:
            name = os.path.splitext(os.path.basename(path))[0]  # portable; avoids splitting on '/'
            self.cache[path] = os.path.join(root, f'{name}.txt')

        self.paths = list(self.cache.keys())

        self.targetHeight = targetHeight
        self.targetWidth = targetWidth
        self.numClass = numClass

        if train:
            self.transforms = T.Compose([
                T.RandomOrder([
                    T.RandomHorizontalFlip(),
                    T.RandomVerticalFlip(),
                    T.RandomSizeCrop(numClass)
                ]),
                T.Resize((targetHeight, targetWidth)),
                T.ColorJitter(brightness=.2, contrast=.1, saturation=.1,
                              hue=0),
                T.Normalize()
            ])
        else:
            self.transforms = T.Compose(
                [T.Resize((targetHeight, targetWidth)),
                 T.Normalize()])
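The fields set above are enough for the standard length accessor; a minimal sketch:

    def __len__(self):
        # One sample per cached image path.
        return len(self.paths)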
Example #8
    def __init__(self, root=None, dataloader=default_loader):
        self.transform1 = transforms.Compose([
            transforms.RandomRotation(30),
            transforms.Resize([256, 256]),
            transforms.RandomCrop(INPUT_SIZE),
            transforms.RandomHorizontalFlip(),
            transforms.ColorJitter(brightness=(0.9, 1.1),
                                   contrast=(0.9, 1.1),
                                   saturation=(0.9, 1.1)),
            transforms.ToTensor(),
            transforms.Normalize(mean=(0.485, 0.456, 0.406),
                                 std=(0.229, 0.224, 0.225)),
            transforms.RandomErasing(probability=0.5, sh=0.05)  # custom RandomErasing (torchvision's takes p/scale/ratio)
        ])
        # Augmentation method 2: focus on smaller regions
        self.transform2 = transforms.Compose([
            transforms.RandomRotation(30),
            transforms.Resize([336, 336]),
            transforms.RandomCrop(INPUT_SIZE),
            transforms.RandomHorizontalFlip(),
            transforms.ColorJitter(brightness=(0.9, 1.1),
                                   contrast=(0.9, 1.1),
                                   saturation=(0.9, 1.1)),
            transforms.ToTensor(),
            transforms.Normalize(mean=(0.485, 0.456, 0.406),
                                 std=(0.229, 0.224, 0.225)),
            transforms.RandomErasing(probability=0.5, sh=0.05)  # custom RandomErasing (torchvision's takes p/scale/ratio)
        ])
        self.dataloader = dataloader

        self.root = root
        with open(os.path.join(self.root, TRAIN_DATASET), 'r') as fid:
            self.imglist = fid.readlines()

        self.labels = []
        for line in self.imglist:
            image_path, label = line.strip().split()
            self.labels.append(int(label))
        self.labels = np.array(self.labels)
        self.labels = torch.LongTensor(self.labels)
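A hedged __getitem__ sketch showing how the two pipelines above are typically consumed; returning two augmented views plus the label is an assumption, since the original method is not shown:

    def __getitem__(self, index):
        image_path = self.imglist[index].strip().split()[0]
        img = self.dataloader(image_path)
        # Two independently augmented views of the same image (assumed usage).
        return self.transform1(img), self.transform2(img), self.labels[index]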
Example #9
def get_dataloader():
    # TODO(xwd): Adaptive normalization for very large images.
    # E.g. in medical image processing, WSI images are very large and differ from ordinary images.

    value_scale = 255
    mean = [0.485, 0.456, 0.406]
    mean = [item * value_scale for item in mean]
    std = [0.229, 0.224, 0.225]
    std = [item * value_scale for item in std]

    train_transform = transform.Compose([
        transform.RandomScale([cfg['scale_min'], cfg['scale_max']]),
        transform.RandomRotate([cfg['rotate_min'], cfg['rotate_max']],
                               padding=mean,
                               ignore_label=cfg['ignore_label']),
        transform.RandomGaussianBlur(),
        transform.RandomHorizontallyFlip(),
        transform.RandomCrop([cfg['train_h'], cfg['train_w']],
                             crop_type='rand',
                             padding=mean,
                             ignore_label=cfg['ignore_label']),
        transform.ToTensor(),
        transform.Normalize(mean=mean, std=std)
    ])

    train_data = cityscapes.Cityscapes(cfg['data_path'],
                                       split='train',
                                       transform=train_transform)

    # Use a distributed sampler so each GPU loads a distinct shard of the dataset, avoiding redundant data.
    train_sampler = DistributedSampler(train_data)

    train_loader = DataLoader(train_data,
                              batch_size=cfg['batch_size'] //
                              cfg['world_size'],
                              shuffle=(train_sampler is None),
                              num_workers=4,
                              pin_memory=True,
                              sampler=train_sampler,
                              drop_last=True)

    return train_loader, train_sampler
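A usage sketch: with a DistributedSampler, set_epoch must be called at the start of each epoch so the shards are reshuffled (num_epochs is a placeholder):

    train_loader, train_sampler = get_dataloader()
    for epoch in range(num_epochs):
        train_sampler.set_epoch(epoch)  # reshuffle the per-GPU shards
        for images, labels in train_loader:
            pass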
Example #10
    def __init__(self, root=None, dataloader=default_loader):
        self.transform = transforms.Compose([
            transforms.Resize([256, 256]),
            transforms.RandomCrop(INPUT_SIZE),
            transforms.ToTensor(),
            transforms.Normalize(mean=(0.485, 0.456, 0.406),
                                 std=(0.229, 0.224, 0.225))
        ])
        self.dataloader = dataloader

        self.root = root
        self.imgs = []
        self.labels = []

        # Note: eagerly decodes and caches every evaluation image in memory.
        with open(os.path.join(self.root, EVAL_DATASET), 'r') as fid:
            for line in fid.readlines():
                img_path, label = line.strip().split()
                img = self.dataloader(img_path)
                label = int(label)
                self.imgs.append(img)
                self.labels.append(label)
Example #11
def deploy(path):
    assert os.path.exists(path), f'{path} not found : ('
    dataset = 'YOUR_DATASET_NAME'

    img_size = 256
    test_transform = transforms.Compose([
        transforms.Resize((img_size, img_size)),
        transforms.ToTensor(),
        transforms.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5))
    ])
    testA = ImageFolder(os.path.join('dataset', dataset, 'testA'), test_transform)
    with fluid.dygraph.guard(): 
        testA_loader = DataLoader(testA, batch_size=1, shuffle=False)
        real_A, _ = next(iter(testA_loader))
        in_np = real_A.numpy()

    # load model
    place = fluid.CPUPlace()
    exe = fluid.Executor(place)
    program, feed_vars, fetch_vars = fluid.io.load_inference_model(path, exe)

    # inference
    fetch, = exe.run(program, feed={feed_vars[0]: in_np}, fetch_list=fetch_vars)
    def img_postprocess(img):
        assert isinstance(img, np.ndarray), type(img)
        img = img * 0.5 + 0.5
        img = img.squeeze(0).transpose((1, 2, 0))
        # BGR to RGB
        img = img[:, :, ::-1]
        return img
    in_img = img_postprocess(in_np)
    out_img = img_postprocess(fetch)
    plt.subplot(121)
    plt.title('real A')
    plt.imshow(in_img)
    plt.subplot(122)
    plt.title('A to B')
    plt.imshow(out_img)
    plt.show()
Example #12
def main(args):
    print(torch.cuda.get_device_name(torch.cuda.current_device()))
    #     torch.cuda.clear_memory_allocated()
    torch.cuda.empty_cache()
    print(torch.cuda.memory_summary(device=0, abbreviated=False))
    print(torch.cuda.list_gpu_processes(device=None))
    logger = SummaryWriter()
    model = InfoNCE(args.net, args.moco_dim, args.moco_k, args.moco_m,
                    args.moco_t).cuda()
    transform = get_transform(args)
    ds = YT8M_Single_Modality('youtube8m_flow/', 32, transform)

    print(ds[0].shape)
    params = []
    for name, param in model.named_parameters():
        params.append({'params': param})
    criterion = nn.CrossEntropyLoss().cuda()
    optimizer = optim.Adam(params, lr=args.lr, weight_decay=args.wd)
    dl = get_dataloader(ds, args)
    transform_train_cuda = transforms.Compose([
        T.Normalize(mean=[0.485, 0.456, 0.406],
                    std=[0.229, 0.224, 0.225],
                    channel=1)
    ])

    for ep in range(100000):
        print('epoch', ep)
        train_one_epoch(dl, model, criterion, optimizer, transform_train_cuda,
                        1, args, logger)
        save_dict = {
            'epoch': ep,
            'state_dict': model.state_dict(),
            'optimizer': optimizer.state_dict(),
            'iteration': args.iteration
        }
        torch.save(save_dict, 'flow_nce/ckpt.pth.tar')
Example #13
def main(args=None):
    torch.manual_seed(args.seed)
    os.environ['CUDA_VISIBLE_DEVICES'] = args.gpu_devices
    # torch.cuda.set_device(0)
    use_gpu = torch.cuda.is_available()

    if args.use_cpu: use_gpu = False

    if not args.evaluate:
        sys.stdout = Logger(osp.join(args.save_dir, args.log_train))
    else:
        sys.stdout = Logger(osp.join(args.save_dir, args.log_test))
    print("==========\nArgs:{}\n==========".format(args))

    if use_gpu:
        print("Currently using GPU {}".format(args.gpu_devices))
        cudnn.benchmark = True
        torch.cuda.manual_seed_all(args.seed)
    else:
        print("Currently using CPU (GPU is highly recommended)")

    print("Initializing dataset {}".format(args.dataset))
    dataset = data_manager.init_dataset(name=args.dataset)

    transform_train = T.Compose([
        T.ToTensor(),
        T.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
        T.RandomErasing(),
    ])

    transform_train2 = T.Compose([
        T.Resize((args.height, args.width)),
        T.Random2DTranslation(args.height, args.width),
    ])

    transform_test = T.Compose([
        T.Resize((args.height, args.width)),
        T.ToTensor(),
        T.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    ])

    pin_memory = False

    trainloader = DataLoader(
        VideoDataset(dataset.train,
                     data_name=args.dataset,
                     seq_len=args.seq_len,
                     sample='random',
                     transform=transform_train,
                     transform2=transform_train2,
                     type="train"),
        sampler=RandomIdentitySampler(dataset.train,
                                      num_instances=args.num_instances),
        batch_size=args.train_batch,
        num_workers=args.workers,
        pin_memory=pin_memory,
        drop_last=True,
    )

    queryloader = DataLoader(
        VideoDataset(dataset.query,
                     data_name=args.dataset,
                     seq_len=args.seq_len,
                     sample='dense',
                     transform=transform_test,
                     type="test"),
        batch_size=args.test_batch,
        shuffle=False,
        num_workers=args.workers,
        pin_memory=pin_memory,
        drop_last=False,
    )

    galleryloader = DataLoader(
        VideoDataset(dataset.gallery,
                     data_name=args.dataset,
                     seq_len=args.seq_len,
                     sample='dense',
                     transform=transform_test,
                     type="test"),
        batch_size=args.test_batch,
        shuffle=False,
        num_workers=args.workers,
        pin_memory=pin_memory,
        drop_last=False,
    )

    print("Initializing models: {}".format(args.arch))
    model = models.init_model(name=args.arch,
                              num_classes=dataset.num_train_pids,
                              final_dim=args.feat_dim)
    print("Model size: {:.5f}M".format(
        sum(p.numel() for p in model.parameters()) / 1000000.0))

    crossEntropyLoss = CrossEntropyLabelSmooth(
        num_classes=dataset.num_train_pids, use_gpu=use_gpu)
    tripletLoss = TripletLoss(margin=args.margin)
    regularLoss = RegularLoss(use_gpu=use_gpu)

    optimizer = torch.optim.Adam(model.parameters(),
                                 lr=args.lr,
                                 weight_decay=args.weight_decay)
    scheduler = WarmupMultiStepLR(optimizer, args.stepsize, args.gamma,
                                  args.warmup_factor, args.warmup_items,
                                  args.warmup_method)
    start_epoch = args.start_epoch

    if use_gpu:
        model = nn.DataParallel(model).cuda()

    if args.evaluate:
        print("Evaluate only")
        atest(model, queryloader, galleryloader, use_gpu)
        return

    start_time = time.time()
    best_rank1 = -np.inf
    for epoch in range(start_epoch, args.max_epoch):
        print("==> Epoch {}/{}".format(epoch + 1, args.max_epoch))

        train(model, crossEntropyLoss, tripletLoss, regularLoss, optimizer,
              trainloader, use_gpu)

        # if args.stepsize > 0:
        scheduler.step()

        if (epoch + 1) >= 200 and (epoch + 1) % args.eval_step == 0:
            print("==> Test")
            rank1 = atest(model, queryloader, galleryloader, use_gpu)
            is_best = rank1 > best_rank1
            if is_best: best_rank1 = rank1

            if use_gpu:
                state_dict = model.module.state_dict()
            else:
                state_dict = model.state_dict()
            save_checkpoint({
                'state_dict': state_dict,
            }, is_best,
                            osp.join(
                                args.save_dir,
                                args.model_name + str(epoch + 1) + '.pth.tar'))

    elapsed = round(time.time() - start_time)
    elapsed = str(datetime.timedelta(seconds=elapsed))
    print("Finished. Total elapsed time (h:m:s): {}".format(elapsed))
Example #14
def main(args=None):
    torch.manual_seed(args.seed)
    os.environ['CUDA_VISIBLE_DEVICES'] = args.gpu_devices
    # torch.cuda.set_device(0)
    use_gpu = torch.cuda.is_available()

    if args.use_cpu: use_gpu = False

    if not args.evaluate:
        sys.stdout = Logger(osp.join(args.save_dir, args.log_train))
    else:
        sys.stdout = Logger(osp.join(args.save_dir, args.log_test))
    print("==========\nArgs:{}\n==========".format(args))

    if use_gpu:
        print("Currently using GPU {}".format(args.gpu_devices))
        cudnn.benchmark = True
        torch.cuda.manual_seed_all(args.seed)
    else:
        print("Currently using CPU (GPU is highly recommended)")

    print("Initializing dataset {}".format(args.dataset))
    dataset = data_manager.init_dataset(name=args.dataset)

    transform_test = T.Compose([
        T.Resize((args.height, args.width)),
        T.ToTensor(),
        T.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    ])

    pin_memory = False

    queryloader = DataLoader(
        VideoDataset(dataset.query,
                     data_name=args.dataset,
                     seq_len=args.seq_len,
                     sample='dense',
                     transform=transform_test,
                     type="test"),
        batch_size=args.test_batch,
        shuffle=False,
        num_workers=args.workers,
        pin_memory=pin_memory,
        drop_last=False,
    )

    galleryloader = DataLoader(
        VideoDataset(dataset.gallery,
                     data_name=args.dataset,
                     seq_len=args.seq_len,
                     sample='dense',
                     transform=transform_test,
                     type="test"),
        batch_size=args.test_batch,
        shuffle=False,
        num_workers=args.workers,
        pin_memory=pin_memory,
        drop_last=False,
    )

    print("Initializing models: {}".format(args.arch))
    model = models.init_model(name=args.arch,
                              num_classes=dataset.num_train_pids,
                              final_dim=args.feat_dim)
    print("Model size: {:.5f}M".format(
        sum(p.numel() for p in model.parameters()) / 1000000.0))

    checkpoint = torch.load(args.model_path)
    model.load_state_dict(checkpoint['state_dict'])

    if use_gpu:
        model = nn.DataParallel(model).cuda()

    if args.evaluate:
        print("Evaluate only")
        atest(model, queryloader, galleryloader, use_gpu)
        return
Example #15
def main():
    net = AFENet(classes=21,
                 pretrained_model_path=args['pretrained_model_path']).cuda()
    net_ori = [net.layer0, net.layer1, net.layer2, net.layer3, net.layer4]
    net_new = [
        net.ppm, net.cls, net.aux, net.ppm_reduce, net.aff1, net.aff2, net.aff3
    ]

    value_scale = 255
    mean = [0.485, 0.456, 0.406]
    mean = [item * value_scale for item in mean]
    std = [0.229, 0.224, 0.225]
    std = [item * value_scale for item in std]

    train_transform = transform.Compose([
        transform.RandScale([0.75, 2.0]),
        transform.RandomHorizontalFlip(),
        transform.Crop([480, 480],
                       crop_type='rand',
                       padding=mean,
                       ignore_label=255),
        transform.ToTensor(),
        transform.Normalize(mean=mean, std=std)
    ])

    train_data = voc2012.VOC2012(split='train',
                                 data_root=args['dataset_root'],
                                 data_list=args['train_list'],
                                 transform=train_transform)
    train_loader = DataLoader(train_data,
                              batch_size=args['train_batch_size'],
                              shuffle=True,
                              num_workers=8,
                              drop_last=True)

    val_transform = transform.Compose([
        transform.Crop([480, 480],
                       crop_type='center',
                       padding=mean,
                       ignore_label=255),
        transform.ToTensor(),
        transform.Normalize(mean=mean, std=std)
    ])
    val_data = voc2012.VOC2012(split='val',
                               data_root=args['dataset_root'],
                               data_list=args['val_list'],
                               transform=val_transform)

    val_loader = DataLoader(val_data,
                            batch_size=args['val_batch_size'],
                            shuffle=False,
                            num_workers=8)

    if len(args['snapshot']) == 0:
        curr_epoch = 1
        args['best_record'] = {
            'epoch': 0,
            'val_loss': 1e10,
            'acc': 0,
            'acc_cls': 0,
            'mean_iu': 0,
            'fwavacc': 0
        }
    else:
        print('training resumes from ' + args['snapshot'])
        net.load_state_dict(
            torch.load(os.path.join(args['model_save_path'],
                                    args['snapshot'])))
        split_snapshot = args['snapshot'].split('_')
        curr_epoch = int(split_snapshot[1]) + 1
        args['best_record'] = {
            'epoch': int(split_snapshot[1]),
            'val_loss': float(split_snapshot[3]),
            'acc': float(split_snapshot[5]),
            'acc_cls': float(split_snapshot[7]),
            'mean_iu': float(split_snapshot[9]),
            'fwavacc': float(split_snapshot[11])
        }
    params_list = []
    for module in net_ori:
        params_list.append(dict(params=module.parameters(), lr=args['lr']))
    for module in net_new:
        params_list.append(dict(params=module.parameters(),
                                lr=args['lr'] * 10))
    args['index_split'] = 5

    criterion = torch.nn.CrossEntropyLoss(ignore_index=255)

    optimizer = torch.optim.SGD(params_list,
                                lr=args['lr'],
                                momentum=args['momentum'],
                                weight_decay=args['weight_decay'])
    if len(args['snapshot']) > 0:
        optimizer.load_state_dict(
            torch.load(
                os.path.join(args['model_save_path'],
                             'opt_' + args['snapshot'])))

    check_makedirs(args['model_save_path'])

    all_iter = args['epoch_num'] * len(train_loader)
    for epoch in range(curr_epoch, args['epoch_num'] + 1):
        train(train_loader, net, optimizer, epoch, all_iter)
        validate(val_loader, net, criterion, optimizer, epoch)
Example #16
def main_worker(gpu, ngpus_per_node, args):
    best_acc = 0
    args.gpu = gpu

    if args.distributed:
        if args.local_rank != -1:
            args.rank = args.local_rank
            args.gpu = args.local_rank
        elif 'SLURM_PROCID' in os.environ: # slurm scheduler
            args.rank = int(os.environ['SLURM_PROCID'])
            args.gpu = args.rank % torch.cuda.device_count()
        elif args.dist_url == "env://" and args.rank == -1:
            args.rank = int(os.environ["RANK"])
            if args.multiprocessing_distributed:
                # For multiprocessing distributed training, rank needs to be the
                # global rank among all the processes
                args.rank = args.rank * ngpus_per_node + gpu
        
        dist.init_process_group(backend=args.dist_backend, init_method=args.dist_url,
                                world_size=args.world_size, rank=args.rank)

    args.print = args.gpu == 0
    # suppress printing if not master
    if (args.multiprocessing_distributed and args.gpu != 0) or\
       (args.local_rank != -1 and args.gpu != 0) or\
       ('SLURM_PROCID' in os.environ and args.rank!=0):
        def print_pass(*args):
            pass
        builtins.print = print_pass

    ### model ###
    print("=> creating {} model with '{}' backbone".format(args.model, args.net))
    if args.model == 'coclr':
        model = CoCLR(args.net, args.moco_dim, args.moco_k, args.moco_m, args.moco_t, topk=args.topk, reverse=args.reverse)
        if args.reverse:
            print('[Warning] using RGB-Mining to help flow')
        else:
            print('[Warning] using Flow-Mining to help RGB')
    else:
        raise NotImplementedError
    args.num_seq = 2
    print('Re-write num_seq to %d' % args.num_seq)
        
    args.img_path, args.model_path, args.exp_path = set_path(args)

    # print(model)

    if args.distributed:
        # For multiprocessing distributed, DistributedDataParallel constructor
        # should always set the single device scope, otherwise,
        # DistributedDataParallel will use all available devices.
        if args.gpu is not None:
            torch.cuda.set_device(args.gpu)
            model.cuda(args.gpu)
            # When using a single GPU per process and per
            # DistributedDataParallel, we need to divide the batch size
            # ourselves based on the total number of GPUs we have
            model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[args.gpu])
            model_without_ddp = model.module
        else:
            model.cuda()
            # DistributedDataParallel will divide and allocate batch_size to all
            # available GPUs if device_ids are not set
            model = torch.nn.parallel.DistributedDataParallel(model)
            model_without_ddp = model.module
    elif args.gpu is not None:
        torch.cuda.set_device(args.gpu)
        model = model.cuda(args.gpu)
        # comment out the following line for debugging
        # raise NotImplementedError("Only DistributedDataParallel is supported.")
    else:
        # AllGather implementation (batch shuffle, queue update, etc.) in
        # this code only supports DistributedDataParallel.
        raise NotImplementedError("Only DistributedDataParallel is supported.")


    ### optimizer ###
    params = []
    for name, param in model.named_parameters():
        params.append({'params': param})

    print('\n===========Check Grad============')
    for name, param in model.named_parameters():
        print(name, param.requires_grad)
    print('=================================\n')

    optimizer = optim.Adam(params, lr=args.lr, weight_decay=args.wd)

    criterion = nn.CrossEntropyLoss().cuda(args.gpu)
    args.iteration = 1

    ### data ###  
    transform_train = get_transform('train', args)
    train_loader = get_dataloader(get_data(transform_train, 'train', args), 'train', args)
    transform_train_cuda = transforms.Compose([
                T.Normalize(mean=[0.485, 0.456, 0.406],
                            std=[0.229, 0.224, 0.225], channel=1)])
    n_data = len(train_loader.dataset)

    print('===================================')

    lr_scheduler = None

    ### restart training ### 
    if args.resume:
        if os.path.isfile(args.resume):
            checkpoint = torch.load(args.resume, map_location=torch.device('cpu'))
            args.start_epoch = checkpoint['epoch']+1
            args.iteration = checkpoint['iteration']
            best_acc = checkpoint['best_acc']
            state_dict = checkpoint['state_dict']

            try: model_without_ddp.load_state_dict(state_dict)
            except: 
                print('[WARNING] Non-Equal load for resuming training!')
                neq_load_customized(model_without_ddp, state_dict, verbose=True)

            print("=> load resumed checkpoint '{}' (epoch {})".format(args.resume, checkpoint['epoch']))
            try: optimizer.load_state_dict(checkpoint['optimizer'])
            except: print('[WARNING] Not loading optimizer states')
        else:
            print("[Warning] no checkpoint found at '{}', use random init".format(args.resume))

    elif args.pretrain != ['random', 'random']:
        # first path: weights to be trained
        # second path: weights as the oracle, not trained
        if os.path.isfile(args.pretrain[1]): # second network --> load as sampler
            checkpoint = torch.load(args.pretrain[1], map_location=torch.device('cpu'))
            second_dict = checkpoint['state_dict']
            new_dict = {}
            for k,v in second_dict.items(): # only take the encoder_q
                if 'encoder_q.' in k:
                    k = k.replace('encoder_q.', 'sampler.')
                    new_dict[k] = v
            second_dict = new_dict

            new_dict = {} # remove queue, queue_ptr
            for k, v in second_dict.items():
                if 'queue' not in k:
                    new_dict[k] = v 
            second_dict = new_dict
            print("=> Use Oracle checkpoint '{}' (epoch {})".format(args.pretrain[1], checkpoint['epoch']))
        else:
            print("=> NO Oracle checkpoint found at '{}', use random init".format(args.pretrain[1]))
            second_dict = {}

        if os.path.isfile(args.pretrain[0]): # first network --> load both encoder q & k
            checkpoint = torch.load(args.pretrain[0], map_location=torch.device('cpu'))
            first_dict = checkpoint['state_dict']

            new_dict = {} # remove queue, queue_ptr
            for k, v in first_dict.items():
                if 'queue' not in k:
                    new_dict[k] = v 
            first_dict = new_dict

            # update both q and k with q
            new_dict = {}
            for k,v in first_dict.items(): # only take the encoder_q
                if 'encoder_q.' in k:
                    new_dict[k] = v
                    k = k.replace('encoder_q.', 'encoder_k.')
                    new_dict[k] = v
            first_dict = new_dict
            
            print("=> Use Training checkpoint '{}' (epoch {})".format(args.pretrain[0], checkpoint['epoch']))
        else:
            print("=> NO Training checkpoint found at '{}', use random init".format(args.pretrain[0]))
            first_dict = {}

        state_dict = {**first_dict, **second_dict}
        try:
            del state_dict['queue_label'] # always re-fill the queue
        except:
            pass 
        neq_load_customized(model_without_ddp, state_dict, verbose=True)

    else:
        print("=> train from scratch")

    torch.backends.cudnn.benchmark = True

    # tensorboard plot tools
    writer_train = SummaryWriter(logdir=os.path.join(args.img_path, 'train'))
    args.train_plotter = TB.PlotterThread(writer_train)
    
    ### main loop ###    
    for epoch in range(args.start_epoch, args.epochs):
        np.random.seed(epoch)
        random.seed(epoch)
        
        if args.distributed:
            train_loader.sampler.set_epoch(epoch)
        adjust_learning_rate(optimizer, epoch, args)

        _, train_acc = train_one_epoch(train_loader, model, criterion, optimizer, transform_train_cuda, epoch, args)
        if (epoch % args.save_freq == 0) or (epoch == args.epochs - 1):         
            # save check_point on rank==0 worker
            if (not args.multiprocessing_distributed and args.rank == 0) \
                or (args.multiprocessing_distributed and args.rank % ngpus_per_node == 0):
                is_best = train_acc > best_acc
                best_acc = max(train_acc, best_acc)
                state_dict = model_without_ddp.state_dict()
                save_dict = {
                    'epoch': epoch,
                    'state_dict': state_dict,
                    'best_acc': best_acc,
                    'optimizer': optimizer.state_dict(),
                    'iteration': args.iteration}
                save_checkpoint(save_dict, is_best, gap=args.save_freq, 
                    filename=os.path.join(args.model_path, 'epoch%d.pth.tar' % epoch), 
                    keep_all='k400' in args.dataset)
    
    print('Training from ep %d to ep %d finished' % (args.start_epoch, args.epochs))
    sys.exit(0)
Example #17
    def build_model(self):
        """ DataLoader """
        pad = int(30 * self.img_size // 256)
        train_transform = transforms.Compose([
            transforms.RandomHorizontalFlip(),
            transforms.Resize((self.img_size + pad, self.img_size + pad)),
            transforms.RandomCrop(self.img_size),
            transforms.ToTensor(),
            transforms.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5))
        ])
        test_transform = transforms.Compose([
            transforms.Resize((self.img_size, self.img_size)),
            transforms.ToTensor(),
            transforms.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5))
        ])

        self.trainA = ImageFolder(
            os.path.join('dataset', self.dataset, 'trainA'), train_transform)
        self.trainB = ImageFolder(
            os.path.join('dataset', self.dataset, 'trainB'), train_transform)
        self.testA = ImageFolder(
            os.path.join('dataset', self.dataset, 'testA'), test_transform)
        self.testB = ImageFolder(
            os.path.join('dataset', self.dataset, 'testB'), test_transform)
        self.trainA_loader = DataLoader(self.trainA,
                                        batch_size=self.batch_size,
                                        shuffle=True)
        self.trainB_loader = DataLoader(self.trainB,
                                        batch_size=self.batch_size,
                                        shuffle=True)
        self.testA_loader = DataLoader(self.testA, batch_size=1, shuffle=False)
        self.testB_loader = DataLoader(self.testB, batch_size=1, shuffle=False)
        """ Define Generator, Discriminator """
        self.genA2B = ResnetGenerator(input_nc=3,
                                      output_nc=3,
                                      ngf=self.ch,
                                      n_blocks=self.n_res,
                                      img_size=self.img_size,
                                      light=self.light)
        self.genB2A = ResnetGenerator(input_nc=3,
                                      output_nc=3,
                                      ngf=self.ch,
                                      n_blocks=self.n_res,
                                      img_size=self.img_size,
                                      light=self.light)
        self.disGA = Discriminator(input_nc=3, ndf=self.ch, n_layers=7)
        self.disGB = Discriminator(input_nc=3, ndf=self.ch, n_layers=7)
        self.disLA = Discriminator(input_nc=3, ndf=self.ch, n_layers=5)
        self.disLB = Discriminator(input_nc=3, ndf=self.ch, n_layers=5)
        """ Define Loss """
        self.L1_loss = loss.L1Loss()
        self.MSE_loss = loss.MSELoss()
        self.BCE_loss = loss.BCEWithLogitsLoss()
        """ Trainer """
        def get_params(block):
            out = []
            for name, param in block.named_parameters():
                if 'instancenorm' in name or 'weight_u' in name or 'weight_v' in name:
                    continue
                out.append(param)
            return out

        genA2B_parameters = get_params(self.genA2B)
        genB2A_parameters = get_params(self.genB2A)
        disGA_parameters = get_params(self.disGA)
        disGB_parameters = get_params(self.disGB)
        disLA_parameters = get_params(self.disLA)
        disLB_parameters = get_params(self.disLB)
        G_parameters = genA2B_parameters + genB2A_parameters
        D_parameters = disGA_parameters + disGB_parameters + disLA_parameters + disLB_parameters
        self.G_optim = fluid.optimizer.Adam(
            parameter_list=G_parameters,
            learning_rate=self.lr,
            beta1=0.5,
            beta2=0.999,
            regularization=fluid.regularizer.L2Decay(self.weight_decay))
        self.D_optim = fluid.optimizer.Adam(
            parameter_list=D_parameters,
            learning_rate=self.lr,
            beta1=0.5,
            beta2=0.999,
            regularization=fluid.regularizer.L2Decay(self.weight_decay))
        """ Define Rho clipper to constraint the value of rho in AdaILN and ILN"""
Example #18
def Dataset_Loader(configer):

    # This helper function loads the configured dataset and creates the torch DataLoader.

    data_configer = configer.dataset_cfg
    Dataset_name = data_configer['id_cfg']['name']
    Data_root = data_configer['id_cfg']['root']
    Data_download = data_configer['id_cfg']['download']

    Model = configer.model["name"]

    ####### Image Transforms builder

    if Model in ["Inceptionv3", "Xception"]:
        if Dataset_name == "Imagenet":
            img_transform = transforms.Compose([
                Imagenet_trans.ToTensor(),
                Imagenet_trans.CenterCrop((299, 299))
            ])
        else:
            img_transform = transforms.Compose(
                [transforms.Resize((299, 299)),
                 transforms.ToTensor()])

    elif Model in [
            "Densenet121", "VGG_19", "Resnet18", "Resnet50", "Resnet34",
            "Resnet101", "Resnet152", "ResNeXt101-32", "ResNeXt101-64"
    ]:

        if Dataset_name == "Imagenet":
            img_transform = transforms.Compose([
                Imagenet_trans.ToTensor(),
                Imagenet_trans.CenterCrop((224, 224))
            ])
        elif Dataset_name == "Caltech":
            img_transform = transforms.Compose([
                transforms.Resize((224, 224)),
                transforms.ToTensor(),
                transforms.Normalize([0.485, 0.456, 0.406],
                                     [0.229, 0.224, 0.225])
            ])
        else:
            img_transform = transforms.Compose(
                [transforms.Resize((224, 224)),
                 transforms.ToTensor()])

    elif Model in ["MobilenetV2"]:
        if Dataset_name == "Imagenet":
            img_transform = transforms.Compose([
                Imagenet_trans.ToTensor(),
                Imagenet_trans.Normalize([0.485, 0.456, 0.406],
                                         [0.229, 0.224, 0.225]),
                Imagenet_trans.CenterCrop((224, 224))
            ])
        else:
            img_transform = transforms.Compose([
                transforms.Resize((224, 224)),
                transforms.ToTensor(),
                transforms.Normalize([0.485, 0.456, 0.406],
                                     [0.229, 0.224, 0.225])
            ])

    else:
        raise ImportError("DL model architecture not supported")

    ####### Dataset train and test builder

    if Dataset_name == "MNIST":  # Shape: (1,28,28)

        if Data_download:

            Trainloader = Data.MNIST(Data_root,
                                     download=True,
                                     train=True,
                                     transform=img_transform)
            Testloader = Data.MNIST(Data_root,
                                    download=True,
                                    train=False,
                                    transform=img_transform)
        else:

            Trainloader = Data.MNIST(os.path.join(Data_root, "MNIST"),
                                     download=False,
                                     train=True,
                                     transform=img_transform)
            Testloader = Data.MNIST(os.path.join(Data_root, "MNIST"),
                                    download=False,
                                    train=False,
                                    transform=img_transform)

    elif Dataset_name == "CIFAR10":  # Shape: (3,32,32)

        if Data_download:

            Trainloader = Data.CIFAR10(Data_root,
                                       download=True,
                                       train=True,
                                       transform=img_transform)
            Testloader = Data.CIFAR10(Data_root,
                                      download=True,
                                      train=False,
                                      transform=img_transform)

        else:

            Trainloader = Data.CIFAR10(os.path.join(Data_root),
                                       download=False,
                                       train=True,
                                       transform=img_transform)
            Testloader = Data.CIFAR10(os.path.join(Data_root),
                                      download=False,
                                      train=False,
                                      transform=img_transform)

    elif Dataset_name == "CIFAR100":  # Shape: (3,32,32)

        if Data_download:

            Trainloader = Data.CIFAR100(Data_root,
                                        download=True,
                                        train=True,
                                        transform=img_transform)
            Testloader = Data.CIFAR100(Data_root,
                                       download=True,
                                       train=False,
                                       transform=img_transform)

        else:

            Trainloader = Data.CIFAR100(Data_root,
                                        download=False,
                                        train=True,
                                        transform=img_transform)
            Testloader = Data.CIFAR100(Data_root,
                                       download=False,
                                       train=False,
                                       transform=img_transform)

    elif Dataset_name == "Fashion-MNIST":

        if Data_download:

            Trainloader = Data.FashionMNIST(Data_root,
                                            download=True,
                                            train=True,
                                            transform=img_transform)
            Testloader = Data.FashionMNIST(Data_root,
                                           download=True,
                                           train=False,
                                           transform=img_transform)

        else:

            Trainloader = Data.FashionMNIST(os.path.join(
                Data_root, "Fashion-MNIST"),
                                            download=False,
                                            train=True,
                                            transform=img_transform)
            Testloader = Data.FashionMNIST(os.path.join(
                Data_root, "Fashion-MNIST"),
                                           download=False,
                                           train=False,
                                           transform=img_transform)

    elif Dataset_name == "SVHN":

        if Data_download:

            Trainloader = Data.SVHN(Data_root,
                                    download=True,
                                    split="train",
                                    transform=img_transform)
            Testloader = Data.SVHN(Data_root,
                                   download=True,
                                   split="test",
                                   transform=img_transform)

        else:

            Trainloader = Data.SVHN(os.path.join(Data_root, "SVHN"),
                                    download=False,
                                    split="train",
                                    transform=img_transform)
            Testloader = Data.SVHN(os.path.join(Data_root, "SVHN"),
                                   download=False,
                                   split="test",
                                   transform=img_transform)

    elif Dataset_name == "STL10":

        if Data_download:

            Trainloader = Data.STL10(os.path.join(Data_root),
                                     download=True,
                                     split="train",
                                     transform=img_transform)
            Testloader = Data.STL10(os.path.join(Data_root),
                                    download=True,
                                    split="test",
                                    transform=img_transform)

        else:

            Trainloader = Data.STL10(os.path.join(Data_root),
                                     download=False,
                                     split="train",
                                     transform=img_transform)
            Testloader = Data.STL10(os.path.join(Data_root),
                                    download=False,
                                    split="test",
                                    transform=img_transform)

    elif Dataset_name == "Caltech":

        if Data_download:

            if not os.path.isdir(os.path.join(Data_root, "Caltech")):
                os.mkdir(os.path.join(Data_root, "Caltech"))

            Trainloader = Caltech256(os.path.join(Data_root, "Caltech"),
                                     download=True,
                                     train=True,
                                     transform=img_transform)
            Testloader = Caltech256(os.path.join(Data_root, "Caltech"),
                                    download=True,
                                    train=False,
                                    transform=img_transform)

        else:

            Trainloader = Caltech256(os.path.join(Data_root, "Caltech"),
                                     train=True,
                                     transform=img_transform)
            Testloader = Caltech256(os.path.join(Data_root, "Caltech"),
                                    train=False,
                                    transform=img_transform)

    elif Dataset_name == "Imagenet":

        Trainloader = ImageFolder(data_path=Data_root, transform=img_transform)
        Testloader = ImageFolder(data_path=Data_root,
                                 transform=img_transform,
                                 Train=False)

    else:
        raise ImportError("Dataset not supported")

    # Creating train and test loaders

    Train_configer = data_configer['train_cfg']
    Val_configer = data_configer['val_cfg']

    train_loader = TD.DataLoader(dataset=Trainloader,
                                 batch_size=Train_configer['batch_size'],
                                 shuffle=Train_configer['shuffle'])
    #num_workers= Train_configer['num_workers'],
    #pin_memory=True)

    test_loader = TD.DataLoader(dataset=Testloader,
                                batch_size=Val_configer['batch_size'],
                                shuffle=Val_configer['shuffle'])
    #num_workers= Val_configer['num_workers'],
    #pin_memory=True)

    print('---------- Training and Test data Loaded ')

    return train_loader, test_loader
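A hedged call site for the helper above, assuming a populated configer object:

    train_loader, test_loader = Dataset_Loader(configer)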
Example #19
from model.deeplab import DeepLab
import matplotlib.pyplot as plt
from tools import prediction
from utils.metrics import Evaluator
args = get_args()
rng = np.random.RandomState(seed=args.seed)

torch.manual_seed(seed=args.seed)

transform_train = trans.Compose([
    trans.RandomHorizontalFlip(),
    #trans.FixScale((args.crop_size,args.crop_size)),
    trans.RandomScale((0.5, 2.0)),
    #trans.FixScale(args.crop_size),
    trans.RandomCrop(args.crop_size),
    trans.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    trans.ToTensor(),
])

transform_val = trans.Compose([
    #trans.FixScale((args.crop_size,args.crop_size)),
    trans.FixScale(args.crop_size),
    trans.CenterCrop(args.crop_size),
    trans.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    trans.ToTensor(),
])
if args.aug:
    voc_train = VOCSegmentation(root='./data',
                                set_name='train',
                                transform=transform_train)
else:
    # Assumed continuation (truncated in the source): train without augmentation.
    voc_train = VOCSegmentation(root='./data',
                                set_name='train',
                                transform=transform_val)
Example #20
#dataset prepare
#---------------------------------
print('Loading dataset...')
cache_size = 256
if args.image_size == 448:
    cache_size = 256 * 2
if args.image_size == 352:
    cache_size = 402
transform_train = transforms.Compose([
    transforms.Resize((cache_size, cache_size)),
    #transforms.Resize((args.image_size, args.image_size)),
    #transforms.RandomRotation(10),
    transforms.RandomCrop(args.image_size),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
])

transform_test = transforms.Compose([
    transforms.Resize((cache_size, cache_size)),
    transforms.CenterCrop(args.image_size),
    transforms.ToTensor(),
    transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
])


print("Dataset Initializing...")
trainset = SRDataset.SRDataset(max_person=args.max_person,
                               image_dir=args.images_root,
                               images_list=args.train_file_pre + '_images.txt',
                               bboxes_list=args.train_file_pre + '_bbox.json',
                               relations_list=args.train_file_pre + '_relation.json',
                               image_size=args.image_size,
                               input_transform=transform_train)
trainloader = torch.utils.data.DataLoader(
    trainset,
    batch_size=args.batch_size,
    shuffle=True,
    num_workers=args.num_workers,
    # The original passed np.random.seed(args.manualSeed), which runs seed()
    # once and hands worker_init_fn the value None; seed each worker instead.
    worker_init_fn=lambda worker_id: np.random.seed(args.manualSeed + worker_id))
Example #21
def main():
    torch.manual_seed(args.seed)
    os.environ['CUDA_VISIBLE_DEVICES'] = args.gpu_devices
    use_gpu = torch.cuda.is_available()
    if args.use_cpu: use_gpu = False

    if not args.evaluate:
        sys.stdout = Logger(osp.join(args.save_dir, 'log_train.txt'))
    else:
        sys.stdout = Logger(osp.join(args.save_dir, 'log_test.txt'))
    print("==========\nArgs:{}\n==========".format(args))

    if use_gpu:
        print("Currently using GPU {}".format(args.gpu_devices))
        cudnn.benchmark = True
        torch.cuda.manual_seed_all(args.seed)
    else:
        print("Currently using CPU (GPU is highly recommended)")

    print("Initializing dataset {}".format(args.dataset))
    dataset = data_manager.init_imgreid_dataset(
        root=args.root, name=args.dataset,
    )

    transform_train = T.Compose([
        T.Random2DTranslation(args.height, args.width),
        T.RandomHorizontalFlip(),
        T.ToTensor(),
        T.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    ])

    transform_test = T.Compose([
        T.Resize((args.height, args.width)),
        T.ToTensor(),
        T.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    ])

    pin_memory = True if use_gpu else False

    trainloader = DataLoader(
        ImageDataset(dataset.train, transform=transform_train),
        sampler=RandomIdentitySampler(dataset.train, num_instances=args.num_instances),
        batch_size=args.train_batch, num_workers=args.workers,
        pin_memory=pin_memory, drop_last=True,
    )

    queryloader = DataLoader(
        ImageDataset(dataset.test_3000_query, transform=transform_test),
        batch_size=args.test_batch, shuffle=False, num_workers=args.workers,
        pin_memory=pin_memory, drop_last=False,
    )

    galleryloader = DataLoader(
        ImageDataset(dataset.test_3000_gallery, transform=transform_test),
        batch_size=args.test_batch, shuffle=False, num_workers=args.workers,
        pin_memory=pin_memory, drop_last=False,
    )

    print("Initializing model: {}".format(args.arch))
    model = models.init_model(name=args.arch, num_classes=dataset.train_vehicle_nums, loss={'xent', 'htri'})
    print("Model size: {:.3f} M".format(sum(p.numel() for p in model.parameters()) / 1000000.0))
    modelid = 0
    if args.upload:
        modelid = upload_data.updatemodel(args.arch, args.lr, args.stepsize, args.gamma, args.loss, args.dataset,
                                          args.height, args.width, args.seq_len, args.train_batch, args.other)

    # criterion_xent = CrossEntropyLabelSmooth(num_classes=dataset.train_vehicle_nums, use_gpu=use_gpu)
    criterion_xent = nn.CrossEntropyLoss()
    criterion_htri = TripletLoss(margin=args.margin)

    optimizer = init_optim(args.optim, model.parameters(), args.lr, args.weight_decay)
    scheduler = lr_scheduler.MultiStepLR(optimizer, milestones=args.stepsize, gamma=args.gamma)
    start_epoch = args.start_epoch

    if args.load_weights:
        # load pretrained weights but ignore layers that don't match in size
        print("Loading pretrained weights from '{}'".format(args.load_weights))
        checkpoint = torch.load(args.load_weights)
        pretrain_dict = checkpoint['state_dict']
        model_dict = model.state_dict()
        pretrain_dict = {k: v for k, v in pretrain_dict.items() if k in model_dict and model_dict[k].size() == v.size()}
        model_dict.update(pretrain_dict)
        model.load_state_dict(model_dict)

    if args.resume:
        checkpoint = torch.load(args.resume)
        model.load_state_dict(checkpoint['state_dict'])
        start_epoch = checkpoint['epoch']
        rank1 = checkpoint['rank1']
        print("Loaded checkpoint from '{}'".format(args.resume))
        print("- start_epoch: {}\n- rank1: {}".format(start_epoch, rank1))

    if use_gpu:
        model = nn.DataParallel(model).cuda()

    if args.evaluate:
        print("Evaluate only")
        test(model, queryloader, galleryloader, use_gpu, return_distmat=True)
        return

    start_time = time.time()
    train_time = 0
    best_rank1 = -np.inf
    best_epoch = 0
    print("==> Start training")

    for epoch in range(start_epoch, args.max_epoch):
        start_train_time = time.time()
        train(modelid, epoch, model, criterion_xent, criterion_htri, optimizer, trainloader, use_gpu)
        train_time += round(time.time() - start_train_time)

        scheduler.step()

        if (epoch + 1) == args.max_epoch or \
                (args.eval_step > 0 and (epoch + 1) > args.start_eval and (epoch + 1) % args.eval_step == 0):
            print("==> Test")
            cmc, mAP = test(model, queryloader, galleryloader, use_gpu)
            rank1 = cmc[0]
            is_best = rank1 > best_rank1

            if is_best:
                best_rank1 = rank1
                best_epoch = epoch + 1

            if use_gpu:
                state_dict = model.module.state_dict()
            else:
                state_dict = model.state_dict()

            save_checkpoint({
                'state_dict': state_dict,
                'rank1': rank1,
                'epoch': epoch + 1,  # store the next epoch, so resuming continues instead of repeating one
            }, is_best, osp.join(args.save_dir, 'checkpoint_ep' + str(epoch + 1) + '.pth.tar'))
            if args.upload:
                upload_data.updatetest(modelid, epoch + 1, mAP.item(), cmc[0].item(), cmc[4].item(), cmc[9].item(),
                                       cmc[19].item())
                upload_data.updaterank(modelid, best_rank1.item())

    print("==> Best Rank-1 {:.1%}, achieved at epoch {}".format(best_rank1, best_epoch))

    elapsed = round(time.time() - start_time)
    elapsed = str(datetime.timedelta(seconds=elapsed))
    train_time = str(datetime.timedelta(seconds=train_time))
    print("Finished. Total elapsed time (h:m:s): {}. Training time (h:m:s): {}.".format(elapsed, train_time))
Example #22
def main():
    """Create the model and start the training."""
    or_nyu_dict = {
        0: 255, 1: 16, 2: 40, 3: 39, 4: 7, 5: 14, 6: 39, 7: 12,
        8: 38, 9: 40, 10: 10, 11: 6, 12: 40, 13: 39, 14: 39, 15: 40,
        16: 18, 17: 40, 18: 4, 19: 40, 20: 40, 21: 5, 22: 40, 23: 40,
        24: 30, 25: 36, 26: 38, 27: 40, 28: 3, 29: 40, 30: 40, 31: 9,
        32: 38, 33: 40, 34: 40, 35: 40, 36: 34, 37: 37, 38: 40, 39: 40,
        40: 39, 41: 8, 42: 3, 43: 1, 44: 2, 45: 22
    }
    or_nyu_map = np.vectorize(lambda x: or_nyu_dict.get(x, x) - 1)
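    # np.vectorize makes the dict lookup run element-wise over whole label
    # arrays, e.g. (sketch): or_nyu_map(np.array([43, 44, 0])) -> array([0, 1, 254]);
    # unmapped ids fall through unchanged via .get(x, x) - 1.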

    device = torch.device("cuda" if not args.cpu else "cpu")

    w, h = map(int, args.input_size.split(','))
    input_size = (w, h)

    w, h = map(int, args.input_size_target.split(','))
    input_size_target = (w, h)

    cudnn.enabled = True
    args.or_nyu_map = or_nyu_map
    # Create network
    if args.model == 'DeepLab':
        model = DeeplabMulti(num_classes=args.num_classes)
        if args.restore_from[:4] == 'http':
            saved_state_dict = model_zoo.load_url(args.restore_from)
        elif args.restore_from == "":
            saved_state_dict = None
        else:
            saved_state_dict = torch.load(args.restore_from)

        # guard against restore_from == "" (nothing to load)
        if saved_state_dict is not None:
            new_params = model.state_dict().copy()
            for i in saved_state_dict:
                # keys look like Scale.layer5.conv2d_list.3.weight
                i_parts = i.split('.')
                # drop the layer5 classifier weights only when num_classes == 40
                if args.num_classes != 40 or i_parts[1] != 'layer5':
                    new_params['.'.join(i_parts[1:])] = saved_state_dict[i]
            model.load_state_dict(new_params)

    model.train()
    model.to(device)

    cudnn.benchmark = True
    if args.mode != "baseline" and args.mode != "baseline_tar":
        # init D
        model_D1 = FCDiscriminator(num_classes=args.num_classes).to(device)
        model_D2 = FCDiscriminator(num_classes=args.num_classes).to(device)

        model_D1.train()
        model_D1.to(device)

        model_D2.train()
        model_D2.to(device)

    if not os.path.exists(args.snapshot_dir):
        os.makedirs(args.snapshot_dir)

    scale_min = 0.5
    scale_max = 2.0
    rotate_min = -10
    rotate_max = 10
    ignore_label = 255
    value_scale = 255
    mean = [0.485, 0.456, 0.406]
    mean = [item * value_scale for item in mean]
    std = [0.229, 0.224, 0.225]
    std = [item * value_scale for item in std]

    args.width = w
    args.height = h
    train_transform = transforms.Compose([
        # et.ExtResize( 512 ),
        transforms.RandScale([scale_min, scale_max]),
        transforms.RandRotate([rotate_min, rotate_max],
                              padding=IMG_MEAN_RGB,
                              ignore_label=ignore_label),
        transforms.RandomGaussianBlur(),
        transforms.RandomHorizontalFlip(),
        transforms.Crop([args.height + 1, args.width + 1],
                        crop_type='rand',
                        padding=IMG_MEAN_RGB,
                        ignore_label=ignore_label),
        #et.ExtRandomCrop(size=(opts.crop_size, opts.crop_size)),
        #et.ExtColorJitter(brightness=0.5, contrast=0.5, saturation=0.5),
        #et.ExtRandomHorizontalFlip(),
        transforms.ToTensor(),
        transforms.Normalize(mean=IMG_MEAN, std=[1, 1, 1]),
    ])

    val_transform = transforms.Compose([
        # et.ExtResize( 512 ),
        transforms.Crop([args.height + 1, args.width + 1],
                        crop_type='center',
                        padding=IMG_MEAN_RGB,
                        ignore_label=ignore_label),
        transforms.ToTensor(),
        transforms.Normalize(mean=IMG_MEAN, std=[1, 1, 1]),
    ])
    if args.mode != "baseline_tar":
        src_train_dst = OpenRoomsSegmentation(root=args.data_dir,
                                              opt=args,
                                              split='train',
                                              transform=train_transform,
                                              imWidth=args.width,
                                              imHeight=args.height,
                                              remap_labels=args.or_nyu_map)
    else:
        src_train_dst = NYU_Labelled(root=args.data_dir_target,
                                     opt=args,
                                     split='train',
                                     transform=train_transform,
                                     imWidth=args.width,
                                     imHeight=args.height,
                                     phase="TRAIN",
                                     randomize=True)
    tar_train_dst = NYU(root=args.data_dir_target,
                        opt=args,
                        split='train',
                        transform=train_transform,
                        imWidth=args.width,
                        imHeight=args.height,
                        phase="TRAIN",
                        randomize=True,
                        mode=args.mode)
    tar_val_dst = NYU(root=args.data_dir,
                      opt=args,
                      split='val',
                      transform=val_transform,
                      imWidth=args.width,
                      imHeight=args.height,
                      phase="TRAIN",
                      randomize=False)
    trainloader = data.DataLoader(src_train_dst,
                                  batch_size=args.batch_size,
                                  shuffle=True,
                                  num_workers=args.num_workers,
                                  pin_memory=True)

    trainloader_iter = enumerate(trainloader)

    targetloader = data.DataLoader(tar_train_dst,
                                   batch_size=args.batch_size,
                                   shuffle=True,
                                   num_workers=args.num_workers,
                                   pin_memory=True)

    targetloader_iter = enumerate(targetloader)

    # implement model.optim_parameters(args) to handle different models' lr setting

    optimizer = optim.SGD(model.optim_parameters(args),
                          lr=args.learning_rate,
                          momentum=args.momentum,
                          weight_decay=args.weight_decay)
    optimizer.zero_grad()
    if args.mode != "baseline" and args.mode != "baseline_tar":
        optimizer_D1 = optim.Adam(model_D1.parameters(),
                                  lr=args.learning_rate_D,
                                  betas=(0.9, 0.99))
        optimizer_D1.zero_grad()

        optimizer_D2 = optim.Adam(model_D2.parameters(),
                                  lr=args.learning_rate_D,
                                  betas=(0.9, 0.99))
        optimizer_D2.zero_grad()

    if args.gan == 'Vanilla':
        bce_loss = torch.nn.BCEWithLogitsLoss()
    elif args.gan == 'LS':
        bce_loss = torch.nn.MSELoss()
    seg_loss = torch.nn.CrossEntropyLoss(ignore_index=255)

    interp = nn.Upsample(size=(input_size[1] + 1, input_size[0] + 1),
                         mode='bilinear',
                         align_corners=True)
    interp_target = nn.Upsample(size=(input_size_target[1] + 1,
                                      input_size_target[0] + 1),
                                mode='bilinear',
                                align_corners=True)

    # labels for adversarial training
    source_label = 0
    target_label = 1

    # set up tensor board
    if args.tensorboard:
        if not os.path.exists(args.log_dir):
            os.makedirs(args.log_dir)

        writer = SummaryWriter(args.log_dir)

    for i_iter in range(args.num_steps):

        loss_seg_value1 = 0
        loss_seg_value1_tar = 0
        loss_adv_target_value1 = 0
        loss_D_value1 = 0

        loss_seg_value2 = 0
        loss_seg_value2_tar = 0
        loss_adv_target_value2 = 0
        loss_D_value2 = 0

        optimizer.zero_grad()
        adjust_learning_rate(optimizer, i_iter)
        if args.mode != "baseline" and args.mode != "baseline_tar":
            optimizer_D1.zero_grad()
            optimizer_D2.zero_grad()
            adjust_learning_rate_D(optimizer_D1, i_iter)
            adjust_learning_rate_D(optimizer_D2, i_iter)
        sample_src = None
        sample_tar = None
        sample_res_src = None
        sample_res_tar = None
        sample_gt_src = None
        sample_gt_tar = None
        for sub_i in range(args.iter_size):

            # train G
            if args.mode != "baseline" and args.mode != "baseline_tar":
                # don't accumulate grads in D
                for param in model_D1.parameters():
                    param.requires_grad = False

                for param in model_D2.parameters():
                    param.requires_grad = False

            # train with source
            try:
                _, batch = next(trainloader_iter)
            except StopIteration:
                # source loader exhausted: restart it for the next pass
                trainloader_iter = enumerate(trainloader)
                _, batch = next(trainloader_iter)
            images, labels, _ = batch
            sample_src = images.clone()
            sample_gt_src = labels.clone()

            images = images.to(device)
            labels = labels.long().to(device)

            pred1, pred2 = model(images)
            pred1 = interp(pred1)
            pred2 = interp(pred2)
            sample_pred_src = pred2.detach().cpu()

            loss_seg1 = seg_loss(pred1, labels)
            loss_seg2 = seg_loss(pred2, labels)
            loss = loss_seg2 + args.lambda_seg * loss_seg1

            # proper normalization
            loss = loss / args.iter_size
            loss.backward()
            loss_seg_value1 += loss_seg1.item() / args.iter_size
            loss_seg_value2 += loss_seg2.item() / args.iter_size

            # train with target
            try:
                _, batch = next(targetloader_iter)
            except StopIteration:
                # target loader exhausted: restart it for the next pass
                targetloader_iter = enumerate(targetloader)
                _, batch = next(targetloader_iter)
            images, tar_labels, _, labelled = batch
            n_labelled = labelled.sum().detach().item()
            batch_size = images.shape[0]
            sample_tar = images.clone()
            sample_gt_tar = tar_labels.clone()
            images = images.to(device)

            pred_target1, pred_target2 = model(images)
            pred_target1 = interp_target(pred_target1)
            pred_target2 = interp_target(pred_target2)
            #print("N_labelled {}".format(n_labelled))
            if args.mode == "sda" and n_labelled != 0:
                labelled = labelled.to(device) == 1
                tar_labels = tar_labels.to(device)
                loss_seg1_tar = seg_loss(pred_target1[labelled],
                                         tar_labels[labelled])
                loss_seg2_tar = seg_loss(pred_target2[labelled],
                                         tar_labels[labelled])
                loss_tar_labelled = loss_seg2_tar + args.lambda_seg * loss_seg1_tar
                loss_tar_labelled = loss_tar_labelled / args.iter_size
                loss_seg_value1_tar += loss_seg1_tar.item() / args.iter_size
                loss_seg_value2_tar += loss_seg2_tar.item() / args.iter_size
            else:
                loss_tar_labelled = torch.zeros(
                    1, requires_grad=True).float().to(device)
            # proper normalization
            sample_pred_tar = pred_target2.detach().cpu()
            if args.mode != "baseline" and args.mode != "baseline_tar":
                D_out1 = model_D1(F.softmax(pred_target1, dim=1))
                D_out2 = model_D2(F.softmax(pred_target2, dim=1))

                loss_adv_target1 = bce_loss(
                    D_out1,
                    torch.FloatTensor(
                        D_out1.data.size()).fill_(source_label).to(device))

                loss_adv_target2 = bce_loss(
                    D_out2,
                    torch.FloatTensor(
                        D_out2.data.size()).fill_(source_label).to(device))

                loss = args.lambda_adv_target1 * loss_adv_target1 + args.lambda_adv_target2 * loss_adv_target2
                loss = loss / args.iter_size + loss_tar_labelled
                #loss = loss_tar_labelled
                loss.backward()
                loss_adv_target_value1 += loss_adv_target1.item() / args.iter_size
                loss_adv_target_value2 += loss_adv_target2.item() / args.iter_size
                # train D

                # bring back requires_grad
                for param in model_D1.parameters():
                    param.requires_grad = True

                for param in model_D2.parameters():
                    param.requires_grad = True

                # train with source
                pred1 = pred1.detach()
                pred2 = pred2.detach()

                D_out1 = model_D1(F.softmax(pred1, dim=1))
                D_out2 = model_D2(F.softmax(pred2, dim=1))

                loss_D1 = bce_loss(
                    D_out1,
                    torch.FloatTensor(
                        D_out1.data.size()).fill_(source_label).to(device))

                loss_D2 = bce_loss(
                    D_out2,
                    torch.FloatTensor(
                        D_out2.data.size()).fill_(source_label).to(device))

                loss_D1 = loss_D1 / args.iter_size / 2
                loss_D2 = loss_D2 / args.iter_size / 2

                loss_D1.backward()
                loss_D2.backward()

                loss_D_value1 += loss_D1.item()
                loss_D_value2 += loss_D2.item()

                # train with target
                pred_target1 = pred_target1.detach()
                pred_target2 = pred_target2.detach()

                D_out1 = model_D1(F.softmax(pred_target1, dim=1))
                D_out2 = model_D2(F.softmax(pred_target2, dim=1))

                loss_D1 = bce_loss(
                    D_out1,
                    torch.FloatTensor(
                        D_out1.data.size()).fill_(target_label).to(device))

                loss_D2 = bce_loss(
                    D_out2,
                    torch.FloatTensor(
                        D_out2.data.size()).fill_(target_label).to(device))

                loss_D1 = loss_D1 / args.iter_size / 2
                loss_D2 = loss_D2 / args.iter_size / 2

                loss_D1.backward()
                loss_D2.backward()

                loss_D_value1 += loss_D1.item()
                loss_D_value2 += loss_D2.item()

        optimizer.step()
        if args.mode != "baseline" and args.mode != "baseline_tar":
            optimizer_D1.step()
            optimizer_D2.step()
        if args.tensorboard:
            scalar_info = {
                'loss_seg1': loss_seg_value1,
                'loss_seg2': loss_seg_value2,
                'loss_adv_target1': loss_adv_target_value1,
                'loss_adv_target2': loss_adv_target_value2,
                'loss_D1': loss_D_value1,
                'loss_D2': loss_D_value2,
                'loss_seg1_tar': loss_seg_value1_tar,
                'loss_seg2_tar': loss_seg_value2_tar,
            }

            if i_iter % 10 == 0:
                for key, val in scalar_info.items():
                    writer.add_scalar(key, val, i_iter)
            if i_iter % 1000 == 0:
                img = sample_src.cpu()[:, [2, 1, 0], :, :] + torch.from_numpy(
                    np.array(IMG_MEAN_RGB).reshape(1, 3, 1, 1)).float()
                img = img.type(torch.uint8)
                writer.add_images("Src/Images", img, i_iter)
                label = tar_train_dst.decode_target(sample_gt_src).transpose(
                    0, 3, 1, 2)
                writer.add_images("Src/Labels", label, i_iter)
                preds = sample_pred_src.permute(0, 2, 3, 1).cpu().numpy()
                preds = np.asarray(np.argmax(preds, axis=3), dtype=np.uint8)
                preds = tar_train_dst.decode_target(preds).transpose(
                    0, 3, 1, 2)
                writer.add_images("Src/Preds", preds, i_iter)

                tar_img = sample_tar.cpu()[:, [2, 1, 0], :, :] + torch.from_numpy(
                    np.array(IMG_MEAN_RGB).reshape(1, 3, 1, 1)).float()
                tar_img = tar_img.type(torch.uint8)
                writer.add_images("Tar/Images", tar_img, i_iter)
                tar_label = tar_train_dst.decode_target(
                    sample_gt_tar).transpose(0, 3, 1, 2)
                writer.add_images("Tar/Labels", tar_label, i_iter)
                tar_preds = sample_pred_tar.permute(0, 2, 3, 1).cpu().numpy()
                tar_preds = np.asarray(np.argmax(tar_preds, axis=3),
                                       dtype=np.uint8)
                tar_preds = tar_train_dst.decode_target(tar_preds).transpose(
                    0, 3, 1, 2)
                writer.add_images("Tar/Preds", tar_preds, i_iter)
        print('exp = {}'.format(args.snapshot_dir))
        print(
            'iter = {0:8d}/{1:8d}, loss_seg1 = {2:.3f} loss_seg2 = {3:.3f} loss_adv1 = {4:.3f}, loss_adv2 = {5:.3f} loss_D1 = {6:.3f} loss_D2 = {7:.3f} loss_seg1_tar={8:.3f} loss_seg2_tar={9:.3f}'
            .format(i_iter, args.num_steps, loss_seg_value1, loss_seg_value2,
                    loss_adv_target_value1, loss_adv_target_value2,
                    loss_D_value1, loss_D_value2, loss_seg_value1_tar,
                    loss_seg_value2_tar))

        if i_iter >= args.num_steps_stop - 1:
            print('save model ...')
            torch.save(
                model.state_dict(),
                osp.join(args.snapshot_dir,
                         'OR_' + str(args.num_steps_stop) + '.pth'))
            if args.mode != "baseline" and args.mode != "baseline_tar":
                torch.save(
                    model_D1.state_dict(),
                    osp.join(args.snapshot_dir,
                             'OR_' + str(args.num_steps_stop) + '_D1.pth'))
                torch.save(
                    model_D2.state_dict(),
                    osp.join(args.snapshot_dir,
                             'OR_' + str(args.num_steps_stop) + '_D2.pth'))
            break

        if i_iter % args.save_pred_every == 0 and i_iter != 0:
            print('taking snapshot ...')
            torch.save(
                model.state_dict(),
                osp.join(args.snapshot_dir, 'OR_' + str(i_iter) + '.pth'))
            if args.mode != "baseline" and args.mode != "baseline_tar":
                torch.save(
                    model_D1.state_dict(),
                    osp.join(args.snapshot_dir,
                             'OR_' + str(i_iter) + '_D1.pth'))
                torch.save(
                    model_D2.state_dict(),
                    osp.join(args.snapshot_dir,
                             'OR_' + str(i_iter) + '_D2.pth'))

    if args.tensorboard:
        writer.close()
Example #23
def main():
    """Create the model and start the evaluation process."""

    args = get_arguments()

    if not os.path.exists(args.save):
        os.makedirs(args.save)

    nyu_nyu_dict = {11: 255, 13: 255, 15: 255, 17: 255, 19: 255, 20: 255,
                    21: 255, 23: 255, 24: 255, 25: 255, 26: 255, 27: 255,
                    28: 255, 29: 255, 31: 255, 32: 255, 33: 255}
    nyu_nyu_map = np.vectorize(lambda x: nyu_nyu_dict.get(x + 1, x))
    args.nyu_nyu_map = nyu_nyu_map
    if args.model == 'DeeplabMulti':
        model = DeeplabMulti(num_classes=args.num_classes)
    elif args.model == 'Oracle':
        model = Res_Deeplab(num_classes=args.num_classes)
        if args.restore_from == RESTORE_FROM:
            args.restore_from = RESTORE_FROM_ORC
    elif args.model == 'DeeplabVGG':
        model = DeeplabVGG(num_classes=args.num_classes)
        if args.restore_from == RESTORE_FROM:
            args.restore_from = RESTORE_FROM_VGG

    if args.restore_from[:4] == 'http':
        saved_state_dict = model_zoo.load_url(args.restore_from)
    else:
        saved_state_dict = torch.load(args.restore_from)
    ### for running different versions of pytorch
    model_dict = model.state_dict()
    saved_state_dict = {k: v for k, v in saved_state_dict.items() if k in model_dict}
    model_dict.update(saved_state_dict)
    ###
    model.load_state_dict(model_dict)  # load the merged dict, not the filtered subset

    device = torch.device("cuda" if not args.cpu else "cpu")
    model = model.to(device)

    model.eval()

    metrics = StreamSegMetrics(args.num_classes)
    metrics_remap = StreamSegMetrics(args.num_classes)
    ignore_label = 255
    value_scale = 255 
    mean = [0.485, 0.456, 0.406]
    mean = [item * value_scale for item in mean]
    std = [0.229, 0.224, 0.225]
    std = [item * value_scale for item in std]
    val_transform = transforms.Compose([
        # et.ExtResize( 512 ),
        transforms.Crop([args.height + 1, args.width + 1], crop_type='center',
                        padding=IMG_MEAN, ignore_label=ignore_label),
        transforms.ToTensor(),
        transforms.Normalize(mean=IMG_MEAN, std=[1, 1, 1]),
    ])
    val_dst = NYU(root=args.data_dir, opt=args,
                  split='val', transform=val_transform,
                  imWidth=args.width, imHeight=args.height, phase="TEST",
                  randomize=False)
    print("Dset Length {}".format(len(val_dst)))
    testloader = data.DataLoader(val_dst, batch_size=1, shuffle=False,
                                 pin_memory=True)

    interp = nn.Upsample(size=(args.height+1, args.width+1), mode='bilinear', align_corners=True)
    metrics.reset()
    for index, batch in enumerate(testloader):
        if index % 100 == 0:
            print('%d processed' % index)
        image, targets, name = batch
        image = image.to(device)
        if args.model == 'DeeplabMulti':
            output1, output2 = model(image)
            output = interp(output2).cpu().data[0].numpy()
        elif args.model == 'DeeplabVGG' or args.model == 'Oracle':
            output = model(image)
            output = interp(output).cpu().data[0].numpy()
        targets = targets.cpu().numpy()
        output = output.transpose(1, 2, 0)
        output = np.asarray(np.argmax(output, axis=2), dtype=np.uint8)
        preds = output[None, :, :]
        #input_ = image.cpu().numpy()[0].transpose(1,2,0) + np.array(IMG_MEAN)
        metrics.update(targets, preds)
        targets = args.nyu_nyu_map(targets)
        preds = args.nyu_nyu_map(preds)
        metrics_remap.update(targets, preds)
        #input_ = Image.fromarray(input_.astype(np.uint8))
        #output_col = colorize_mask(output)
        #output = Image.fromarray(output)
        
        #name = name[0].split('/')[-1]
        #input_.save('%s/%s' % (args.save, name))
        #output_col.save('%s/%s_color.png' % (args.save, name.split('.')[0]))
    print(metrics.get_results())
    print(metrics_remap.get_results())
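# For reference, the score StreamSegMetrics presumably reports is a
# confusion-matrix based mean IoU; a sketch of the standard formula
# (an assumption about the class, not its actual implementation):
def mean_iou_sketch(conf):
    # conf[i, j] = number of pixels with ground truth i predicted as j
    tp = np.diag(conf).astype(np.float64)
    denom = conf.sum(axis=1) + conf.sum(axis=0) - tp  # TP + FN + FP per class
    with np.errstate(divide='ignore', invalid='ignore'):
        iou = tp / denom
    return np.nanmean(iou)  # nan entries (absent classes) are ignored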
Example #24
from torchvision import transforms  # assumed source of transforms.Compose used below
from xt_training import metrics

from utils import transforms as xt_transforms
from datasets.tms_datasets import TMSDataset, loader_binary, loader_csv
from models.resnet_1d import resnet18, resnet50
from models.resnet_2d import Resnet, ResnetThreshold


# Transforms
transforms_q = transforms.Compose([
    xt_transforms.GetS3Mask(0.0),
    # xt_transforms.LowPass(13, pad_length=192),
    xt_transforms.SwapPillars(),
    # xt_transforms.ReduceToPoles(),
    # xt_transforms.AddSensorRatio(),
    xt_transforms.Normalize(),
    xt_transforms.FromNumpy()
])
transforms_k = transforms.Compose([
    xt_transforms.GetS3Mask(0.0),
    # xt_transforms.LowPass(13, pad_length=192),
    # xt_transforms.SwapPillars(),
    # xt_transforms.ReduceToPoles(),
    # xt_transforms.AddSensorRatio(),
    xt_transforms.Normalize(),
    xt_transforms.FromNumpy()
])
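# transforms.Compose only requires callables, so the custom xt_transforms above
# just implement __call__ on numpy signals; a hypothetical example of the
# pattern (not the project's actual Normalize):
class UnitNormalizeSketch:
    def __call__(self, x):
        # standardise each channel along the last (time) axis
        mu = x.mean(axis=-1, keepdims=True)
        sigma = x.std(axis=-1, keepdims=True)
        return (x - mu) / (sigma + 1e-8)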

# Dataloader
batch_size = 256
workers = 4
Example #25
def main_worker(gpu, ngpus_per_node, args):
    args.gpu = gpu
    if args.local_rank != -1: 
        args.gpu = args.local_rank
        torch.cuda.set_device(args.gpu)
    best_acc = 0

    args.print = args.gpu == 0
    # suppress printing if not master
    if (args.multiprocessing_distributed and args.gpu != 0) or\
       (args.local_rank != -1 and args.gpu != 0):
        def print_pass(*args):
            pass
        builtins.print = print_pass

    if args.distributed:
        if args.dist_url == "env://" and args.rank == -1:
            args.rank = int(os.environ["RANK"])
        if args.multiprocessing_distributed:
            # For multiprocessing distributed training, rank needs to be the
            # global rank among all the processes
            args.rank = args.rank * ngpus_per_node + gpu
        if args.local_rank != -1:
            args.rank = args.local_rank
        dist.init_process_group(backend=args.dist_backend, init_method=args.dist_url,
                                world_size=args.world_size, rank=args.rank)
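        # With init_method="env://", rendezvous info comes from environment
        # variables; a typical launch (flag names assumed for this script):
        #   torchrun --nproc_per_node=4 main.py --dist-url env:// ...
        # torchrun exports RANK / WORLD_SIZE / LOCAL_RANK for every process.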

    ### model ###
    print("=> creating {} model with '{}' backbone".format(args.model, args.net))
    if args.model == 'infonce':
        model = InfoNCE(args.net, args.moco_dim, args.moco_k, args.moco_m, args.moco_t)
    elif args.model == 'ubernce':
        model = UberNCE(args.net, args.moco_dim, args.moco_k, args.moco_m, args.moco_t)
    
    args.num_seq = 2
    print('Re-write num_seq to %d' % args.num_seq)

    args.img_path, args.model_path, args.exp_path = set_path(args)

    if args.distributed:
        # For multiprocessing distributed, DistributedDataParallel constructor
        # should always set the single device scope, otherwise,
        # DistributedDataParallel will use all available devices.
        if args.gpu is not None:
            torch.cuda.set_device(args.gpu)
            model.cuda(args.gpu)
            model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[args.gpu])
            model_without_ddp = model.module
        else:
            model.cuda()
            model = torch.nn.parallel.DistributedDataParallel(model)
            model_without_ddp = model.module
    elif args.gpu is not None:
        torch.cuda.set_device(args.gpu)
        model = model.cuda(args.gpu)
        # comment out the following line for debugging
        raise NotImplementedError("Only DistributedDataParallel is supported.")
    else:
        # AllGather implementation (batch shuffle, queue update, etc.) in
        # this code only supports DistributedDataParallel.
        raise NotImplementedError("Only DistributedDataParallel is supported.")


    ### optimizer ###
    params = []
    if args.train_what == 'all':
        for name, param in model.named_parameters():
            params.append({'params': param})
    else:
        raise NotImplementedError

    print('\n===========Check Grad============')
    for name, param in model.named_parameters():
        if not param.requires_grad:
            print(name, param.requires_grad)
    print('=================================\n')

    optimizer = optim.Adam(params, lr=args.lr, weight_decay=args.wd)
    criterion = nn.CrossEntropyLoss().cuda(args.gpu)
    args.iteration = 1

    ### data ###  
    transform_train = get_transform('train', args)
    train_loader = get_dataloader(get_data(transform_train, 'train', args), 'train', args)
    transform_train_cuda = transforms.Compose([
        T.Normalize(mean=[0.485, 0.456, 0.406],
                    std=[0.229, 0.224, 0.225], channel=1)])
    n_data = len(train_loader.dataset)
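    # transform_train_cuda is a GPU-side normalisation applied to batches after
    # they are moved to the device (channel=1 hints at (B, C, T, H, W) video
    # tensors); presumably used inside train_one_epoch roughly as:
    #   clips = clips.cuda(args.gpu, non_blocking=True)
    #   clips = transform_train_cuda(clips)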

    print('===================================')

    ### restart training ### 
    if args.resume:
        if os.path.isfile(args.resume):
            checkpoint = torch.load(args.resume, map_location=torch.device('cpu'))
            args.start_epoch = checkpoint['epoch']+1
            args.iteration = checkpoint['iteration']
            best_acc = checkpoint['best_acc']
            state_dict = checkpoint['state_dict']

            try:
                model_without_ddp.load_state_dict(state_dict)
            except RuntimeError:
                print('[WARNING] resuming training with different weights')
                neq_load_customized(model_without_ddp, state_dict, verbose=True)

            print("=> load resumed checkpoint '{}' (epoch {})".format(args.resume, checkpoint['epoch']))
            try:
                optimizer.load_state_dict(checkpoint['optimizer'])
            except (RuntimeError, ValueError, KeyError):
                print('[WARNING] failed to load optimizer state, initialize optimizer')
        else:
            print("[Warning] no checkpoint found at '{}', use random init".format(args.resume))
    
    elif args.pretrain:
        if os.path.isfile(args.pretrain):
            checkpoint = torch.load(args.pretrain, map_location=torch.device('cpu'))
            state_dict = checkpoint['state_dict']
                
            try:
                model_without_ddp.load_state_dict(state_dict)
            except RuntimeError:
                neq_load_customized(model_without_ddp, state_dict, verbose=True)
            print("=> loaded pretrained checkpoint '{}' (epoch {})".format(args.pretrain, checkpoint['epoch']))
        else:
            print("=> no checkpoint found at '{}', use random init".format(args.pretrain))
    
    else:
        print("=> train from scratch")

    torch.backends.cudnn.benchmark = True

    # tensorboard plot tools
    writer_train = SummaryWriter(logdir=os.path.join(args.img_path, 'train'))  # bare img_path is undefined; set_path stored it on args
    args.train_plotter = TB.PlotterThread(writer_train)

    ### main loop ###    
    for epoch in range(args.start_epoch, args.epochs):
        np.random.seed(epoch)
        random.seed(epoch)
        
        if args.distributed:
            train_loader.sampler.set_epoch(epoch)
        adjust_learning_rate(optimizer, epoch, args)

        _, train_acc = train_one_epoch(train_loader, model, criterion, optimizer, lr_scheduler, transform_train_cuda, epoch, args)
        
        if (epoch % args.save_freq == 0) or (epoch == args.epochs - 1): 
            # save check_point on rank==0 worker
            if (not args.multiprocessing_distributed and args.rank == 0) \
                or (args.multiprocessing_distributed and args.rank % ngpus_per_node == 0):
                is_best = train_acc > best_acc
                best_acc = max(train_acc, best_acc)
                state_dict = model_without_ddp.state_dict()
                save_dict = {
                    'epoch': epoch,
                    'state_dict': state_dict,
                    'best_acc': best_acc,
                    'optimizer': optimizer.state_dict(),
                    'iteration': args.iteration}
                save_checkpoint(save_dict, is_best, gap=args.save_freq, 
                    filename=os.path.join(args.model_path, 'epoch%d.pth.tar' % epoch), 
                    keep_all='k400' in args.dataset)
    
    print('Training from ep %d to ep %d finished' % (args.start_epoch, args.epochs))
    sys.exit(0)
Example #26
def get_dataloaders(checkpoint_dir,
                    rsyncing,
                    selective_sampling=False,
                    warmup_trainer=None,
                    batch_size=16,
                    num_workers=os.cpu_count() - 1,
                    data_aug_vec=[0.5, 0.25, 0.5, 0.5],
                    toy=False,
                    notebook=False,
                    cat=False):
    """

    :param checkpoint_dir:
    :param rsyncing:
    :param selective_sampling:
    :param warmup_trainer:
    :param batch_size:
    :param num_workers:
    :param seed:
    :param data_aug_vec: probabilities for rnd flip, rnd gamma, rnd translation and rnd scale
    :param toy:
    :param notebook:
    :return:
    """
    #     if torch.cuda.is_available():
    #         mp.set_start_method('spawn')
    multiprocessing = False
    num_workers = 0
    sampler_size = 3000

    if rsyncing:
        print('Rsynced data! (prepare feat)', flush=True)
    else:
        print('Using symbolic links! (prepare feat)', flush=True)
    print('Getting path ready..', flush=True)
    anno_path_train, anno_path_val, png_path = get_paths(
        rsyncing, toy, notebook)

    # TODO
    # png_path = os.path.join('/Users/lisa/Documents/Uni/ThesisDS/thesis_ds/one_img_dataset', 'png')
    # anno_path_train = os.path.join('/Users/lisa/Documents/Uni/ThesisDS/thesis_ds/one_img_dataset',
    #                                'annotations/mscoco_train_full.json')
    # anno_path_val = os.path.join('/Users/lisa/Documents/Uni/ThesisDS/thesis_ds/one_img_dataset',
    #                                'annotations/mscoco_train_full.json')

    print('Creating Coco Datasets..', flush=True)
    # t.ToTensor()
    if not cat:
        trans_img = torchvision.transforms.Compose([
            t.Normalize(),
            t.BboxCrop(targetsize=224),
            t.RandomFlipImg(prob=data_aug_vec[0]),
            t.RandomGammaImg(prob=data_aug_vec[1],
                             use_normal_distribution=True)
        ])
        trans_bb = torchvision.transforms.Compose([
            t.GetFiveBBs(),
            t.RandomTranslateBB(prob=data_aug_vec[2], pixel_range=10),
            t.RandomScaleBB(prob=data_aug_vec[3], max_percentage=0.1)
        ])
    else:
        trans_img = torchvision.transforms.Compose([
            t.Normalize(),
            t.BboxCropMult(targetsize=224),
            t.RandomFlipImg(prob=data_aug_vec[0]),
            t.RandomGammaImg(prob=data_aug_vec[1],
                             use_normal_distribution=True)
        ])
        trans_bb = torchvision.transforms.Compose([
            t.GetBBsMult(),
            t.RandomTranslateBB(prob=data_aug_vec[2], pixel_range=10,
                                cat=True),
            t.RandomScaleBB(prob=data_aug_vec[3], max_percentage=0.1, cat=True)
        ])

    trainset = u.dataset_coco(png_path,
                              anno_path_train,
                              transform=trans_img,
                              bbox_transform=trans_bb,
                              for_feature=True,
                              cat=cat)
    print('Training set has', len(trainset), 'images', flush=True)

    if not cat:
        valset = u.dataset_coco(
            png_path,
            anno_path_val,
            transform=torchvision.transforms.Compose(
                [t.Normalize(), t.BboxCrop(targetsize=224)]),
            bbox_transform=torchvision.transforms.Compose([t.GetFiveBBs()]),
            for_feature=True,
            cat=cat)
    else:
        valset = u.dataset_coco(
            png_path,
            anno_path_val,
            transform=torchvision.transforms.Compose(
                [t.Normalize(), t.BboxCropMult(targetsize=224)]),
            bbox_transform=torchvision.transforms.Compose([t.GetBBsMult()]),
            for_feature=True,
            cat=cat)
    print('Validation set has', len(valset), 'images', flush=True)

    if selective_sampling:
        if not warmup_trainer:
            print(
                'Cannot calculate weights for selective sampling: no model given. Using normal sampling instead',
                flush=True)
            trainloader = torch.utils.data.DataLoader(
                trainset,
                batch_size=batch_size,
                sampler=RandomSampler(trainset),
                num_workers=num_workers,
                collate_fn=u.mammo_collate,
                pin_memory=multiprocessing)
        else:
            print('Getting weights for sampling..', flush=True)
            trainloader = torch.utils.data.DataLoader(
                trainset,
                batch_size=batch_size,
                sampler=SequentialSampler(trainset),
                num_workers=num_workers,
                collate_fn=u.mammo_collate,
                pin_memory=multiprocessing)
            weights = warmup_trainer.predict_dataset(trainloader)
            pkl.dump(
                weights,
                open(
                    os.path.join(checkpoint_dir,
                                 'weights_selective_train.pkl'), 'wb'))
            trainloader = torch.utils.data.DataLoader(
                trainset,
                batch_size=batch_size,
                sampler=WeightedRandomSampler(weights.double(),
                                              sampler_size,
                                              replacement=False),
                num_workers=num_workers,
                collate_fn=u.mammo_collate,
                pin_memory=multiprocessing)

    else:
        trainloader = torch.utils.data.DataLoader(
            trainset,
            batch_size=batch_size,
            sampler=RandomSampler(trainset),
            num_workers=num_workers,
            collate_fn=u.mammo_collate,
            pin_memory=multiprocessing)

    valloader = torch.utils.data.DataLoader(valset,
                                            batch_size=batch_size,
                                            sampler=SequentialSampler(valset),
                                            num_workers=num_workers,
                                            collate_fn=u.mammo_collate,
                                            pin_memory=multiprocessing)

    print('Training loader has', len(trainloader), 'batches', flush=True)
    print('Validation loader has', len(valloader), 'batches', flush=True)
    return trainloader, valloader
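# A minimal usage sketch (hypothetical checkpoint path; the batch structure is
# whatever u.mammo_collate produces):
if __name__ == '__main__':
    trainloader, valloader = get_dataloaders('/tmp/ckpt', rsyncing=False,
                                             batch_size=8, toy=True)
    first_batch = next(iter(trainloader))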
Example #27
with open('data/kannada_semi_1pct.pkl', 'rb') as f:
    kannada_semi = pickle.load(f)
kannada_x_train_labeled = kannada_semi["x_train_labeled"]
kannada_y_train_labeled = kannada_semi["y_train_labeled"]
kannada_x_train_unlabeled = kannada_semi["x_train_unlabeled"]
kannada_y_train_unlabeled = kannada_semi["y_train_unlabeled"]
kannada_x_train = np.concatenate((kannada_x_train_labeled, kannada_x_train_unlabeled), axis=0)
kannada_y_train = np.concatenate((kannada_y_train_labeled, kannada_y_train_unlabeled), axis=0)
kannada_x_val = kannada_semi["x_val"]
kannada_y_val = kannada_semi["y_val"]
kannada_x_test = kannada_semi['x_test']
kannada_y_test = kannada_semi['y_test']

train_transform = transforms.Compose([
    transforms.Normalize(mean=0.5, std=0.5),
    transforms.Resize(size=(32, 32))
])

# mnist_x_train / mnist_y_train are assumed to be loaded earlier in the original script
mnist_train_dataset = CustomTensorDataset(torch.from_numpy(mnist_x_train).float(), torch.from_numpy(mnist_y_train).long(), transform=train_transform)
kannada_train_dataset = CustomTensorDataset(torch.from_numpy(kannada_x_train).float(), torch.from_numpy(kannada_y_train).long(), transform=train_transform)
kannada_val_dataset = CustomTensorDataset(torch.from_numpy(kannada_x_val).float(), torch.from_numpy(kannada_y_val).long(), transform=train_transform)

train_loader_1 = DataLoader(mnist_train_dataset, batch_size=args.batch_size, shuffle=True, num_workers=args.num_workers)
train_loader_2 = DataLoader(kannada_train_dataset, batch_size=args.batch_size, shuffle=True, num_workers=args.num_workers)
val_loader = DataLoader(kannada_val_dataset, batch_size=args.batch_size, shuffle=False, num_workers=args.num_workers)

num_batch = min(len(train_loader_1), len(train_loader_2))

dim_z = 512
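# num_batch above reflects that an epoch pairs the two domains batch-by-batch;
# a minimal sketch of the pairing loop (the training step itself is omitted):
#   for (x_mnist, y_mnist), (x_kannada, _) in zip(train_loader_1, train_loader_2):
#       ...  # zip stops after num_batch iterations, i.e. the shorter loader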
Example #28
def get_sequential_trainloader(toy,
                               rsyncing,
                               batch_size=16,
                               num_workers=os.cpu_count() - 1,
                               data_aug_vec=[0.5, 0.25, 0.5, 0.5],
                               notebook=False):
    """

    :param toy:
    :param rsyncing:
    :param batch_size:
    :param num_workers:
    :param data_aug_vec:
    :param notebook:
    :return:
    """
    num_workers = 0
    if rsyncing:
        print('Rsynced data! (prepare feat)', flush=True)
    else:
        print('Using symbolic links! (prepare feat)', flush=True)
    print('Getting path ready..', flush=True)
    anno_path_train, _, png_path = get_paths(rsyncing, toy, notebook)

    # TODO
    # png_path = os.path.join('/Users/lisa/Documents/Uni/ThesisDS/thesis_ds/one_img_dataset', 'png')
    # anno_path_train = os.path.join('/Users/lisa/Documents/Uni/ThesisDS/thesis_ds/one_img_dataset',
    #                                'annotations/mscoco_train_full.json')

    trans_img = torchvision.transforms.Compose([
        t.Normalize(),
        t.BboxCrop(targetsize=224),
        t.RandomFlipImg(prob=data_aug_vec[0]),
        t.RandomGammaImg(prob=data_aug_vec[1], use_normal_distribution=True)
    ])
    trans_bb = torchvision.transforms.Compose([
        t.GetFiveBBs(),
        t.RandomTranslateBB(prob=data_aug_vec[2], pixel_range=10),
        t.RandomScaleBB(prob=data_aug_vec[3], max_percentage=0.1)
    ])

    start_time = time.time()
    print('Creating Coco Dataset..', flush=True)

    trainset = u.dataset_coco(png_path,
                              anno_path_train,
                              transform=trans_img,
                              bbox_transform=trans_bb,
                              for_feature=True)
    print('Training set has', len(trainset), 'images', flush=True)

    trainloader = torch.utils.data.DataLoader(
        trainset,
        batch_size=batch_size,
        sampler=SequentialSampler(trainset),
        num_workers=num_workers,
        collate_fn=u.mammo_collate)
    print('Training loader has', len(trainloader), 'batches', flush=True)

    total_time = time.time() - start_time
    print('Creating Datasets took {:.0f} seconds.'.format(total_time),
          flush=True)

    return trainloader
Example #29
def get_data(data_dir,
             source,
             target,
             source_train_path,
             target_train_path,
             source_extension,
             target_extension,
             height,
             width,
             batch_size,
             re=0,
             workers=8):

    dataset = DA(data_dir, source, target, source_train_path,
                 target_train_path, source_extension, target_extension)

    normalizer = T.Normalize(mean=[0.485, 0.456, 0.406],
                             std=[0.229, 0.224, 0.225])
    source_num_classes = dataset.num_source_train_ids
    train_transformer = T.Compose([
        T.RandomSizedRectCrop(height, width),
        T.RandomHorizontalFlip(),
        T.ToTensor(),
        normalizer,
        T.RandomErasing(EPSILON=re),
    ])
    test_transformer = T.Compose([
        T.Resize((height, width), interpolation=3),  # 3 == PIL.Image.BICUBIC
        T.ToTensor(),
        normalizer,
    ])
    source_train_loader = DataLoader(Preprocessor(
        dataset.source_train,
        root=osp.join(dataset.source_images_dir, dataset.source_train_path),
        transform=train_transformer),
                                     batch_size=batch_size,
                                     num_workers=0,
                                     shuffle=True,
                                     pin_memory=False,
                                     drop_last=True)
    target_train_loader = DataLoader(Preprocessor(
        dataset.target_train,
        root=osp.join(dataset.target_images_dir, dataset.target_train_path),
        transform=train_transformer),
                                     batch_size=batch_size,
                                     num_workers=0,
                                     shuffle=True,
                                     pin_memory=False,
                                     drop_last=True)
    # source_train_loader = DataLoader(
    #     UnsupervisedCamStylePreprocessor(dataset.source_train, root=osp.join(dataset.source_images_dir, dataset.source_train_path),
    #                                      camstyle_root=osp.join(dataset.source_images_dir, dataset.source_train_path),
    #                  transform=train_transformer),
    #     batch_size=batch_size, num_workers=0,
    #     shuffle=True, pin_memory=False, drop_last=True)
    # target_train_loader = DataLoader(
    #     UnsupervisedCamStylePreprocessor(dataset.target_train,
    #                                      root=osp.join(dataset.target_images_dir, dataset.target_train_path),
    #                                      camstyle_root=osp.join(dataset.target_images_dir,
    #                                                             dataset.target_train_camstyle_path),
    #                                      num_cam=dataset.target_num_cam, transform=train_transformer),
    #     batch_size=batch_size, num_workers=workers,
    #     shuffle=True, pin_memory=True, drop_last=True)
    query_loader = DataLoader(Preprocessor(dataset.query,
                                           root=osp.join(
                                               dataset.target_images_dir,
                                               dataset.query_path),
                                           transform=test_transformer),
                              batch_size=batch_size,
                              num_workers=workers,
                              shuffle=False,
                              pin_memory=True)
    gallery_loader = DataLoader(Preprocessor(dataset.gallery,
                                             root=osp.join(
                                                 dataset.target_images_dir,
                                                 dataset.gallery_path),
                                             transform=test_transformer),
                                batch_size=batch_size,
                                num_workers=workers,
                                shuffle=False,
                                pin_memory=True)
    return dataset, source_num_classes, source_train_loader, target_train_loader, query_loader, gallery_loader
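# A minimal usage sketch with hypothetical arguments (dataset names, paths and
# sizes are placeholders, not values taken from this repo):
if __name__ == '__main__':
    (dataset, source_num_classes, source_train_loader, target_train_loader,
     query_loader, gallery_loader) = get_data(
        data_dir='data', source='market1501', target='duke',
        source_train_path='bounding_box_train', target_train_path='bounding_box_train',
        source_extension='jpg', target_extension='jpg',
        height=256, width=128, batch_size=32, re=0.5)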
Example #30
def main(args):
    if args.gpu is None:
        args.gpu = str(os.environ["CUDA_VISIBLE_DEVICES"])
    else:
        os.environ["CUDA_VISIBLE_DEVICES"] = str(args.gpu)
    device = torch.device('cuda')

    best_acc = 0
    torch.manual_seed(0)
    np.random.seed(0)
    random.seed(0)

    num_gpu = len(str(args.gpu).split(','))
    args.batch_size = num_gpu * args.batch_size
    print('=> Effective BatchSize = %d' % args.batch_size)
    args.img_path, args.model_path, args.exp_path = set_path(args)

    ### classifier model ###
    num_class_dict = {
        'ucf101': 101,
        'hmdb51': 51,
        'k400': 400,
        'ucf101-f': 101,
        'hmdb51-f': 51,
        'k400-f': 400
    }
    args.num_class = num_class_dict[args.dataset]

    if args.train_what == 'last':  # for linear probe
        args.final_bn = True
        args.final_norm = True
        args.use_dropout = False
    else:  # for training the entire network
        args.final_bn = False
        args.final_norm = False
        args.use_dropout = True

    if args.model == 'lincls':
        model = LinearClassifier(network=args.net,
                                 num_class=args.num_class,
                                 dropout=args.dropout,
                                 use_dropout=args.use_dropout,
                                 use_final_bn=args.final_bn,
                                 use_l2_norm=args.final_norm)
    else:
        raise NotImplementedError

    model.to(device)

    ### optimizer ###
    if args.train_what == 'last':
        print('=> [optimizer] only train last layer')
        params = []
        for name, param in model.named_parameters():
            if 'backbone' in name:
                param.requires_grad = False
            else:
                params.append({'params': param})

    elif args.train_what == 'ft':
        print('=> [optimizer] finetune backbone with smaller lr')
        params = []
        for name, param in model.named_parameters():
            if 'backbone' in name:
                params.append({'params': param, 'lr': args.lr / 10})
            else:
                params.append({'params': param})

    else:  # train all
        params = []
        print('=> [optimizer] train all layer')
        for name, param in model.named_parameters():
            params.append({'params': param})

    if args.train_what == 'last':
        print('\n===========Check Grad============')
        for name, param in model.named_parameters():
            if param.requires_grad:
                print(name, param.requires_grad)
        print('=================================\n')

    if args.optim == 'adam':
        optimizer = optim.Adam(params, lr=args.lr, weight_decay=args.wd)
    elif args.optim == 'sgd':
        optimizer = optim.SGD(params,
                              lr=args.lr,
                              weight_decay=args.wd,
                              momentum=0.9)
    else:
        raise NotImplementedError

    model = torch.nn.DataParallel(model)
    model_without_dp = model.module

    ce_loss = nn.CrossEntropyLoss()
    args.iteration = 1

    ### test: higher priority ###
    if args.test:
        if os.path.isfile(args.test):
            print("=> loading testing checkpoint '{}'".format(args.test))
            checkpoint = torch.load(args.test,
                                    map_location=torch.device('cpu'))
            epoch = checkpoint['epoch']
            state_dict = checkpoint['state_dict']

            if args.retrieval_ucf or args.retrieval_full:  # if directly test on pretrained network
                new_dict = {}
                for k, v in state_dict.items():
                    k = k.replace('encoder_q.0.', 'backbone.')
                    new_dict[k] = v
                state_dict = new_dict

            try:
                model_without_dp.load_state_dict(state_dict)
            except RuntimeError:
                neq_load_customized(model_without_dp, state_dict, verbose=True)

        else:
            print("[Warning] no checkpoint found at '{}'".format(args.test))
            epoch = 0
            print("[Warning] if test random init weights, press c to continue")
            import ipdb
            ipdb.set_trace()

        args.logger = Logger(path=os.path.dirname(args.test))
        args.logger.log('args=\n\t\t' + '\n\t\t'.join(
            ['%s:%s' % (str(k), str(v)) for k, v in vars(args).items()]))

        transform_test_cuda = transforms.Compose([
            T.Normalize(mean=[0.485, 0.456, 0.406],
                        std=[0.229, 0.224, 0.225],
                        channel=1)
        ])

        if args.retrieval:
            test_retrieval(model, ce_loss, transform_test_cuda, device, epoch,
                           args)
        elif args.center_crop or args.five_crop or args.ten_crop:
            transform = get_transform('test', args)
            test_dataset = get_data(transform, 'test', args)
            test_10crop(test_dataset, model, ce_loss, transform_test_cuda,
                        device, epoch, args)
        else:
            raise NotImplementedError

        sys.exit(0)

    ### data ###
    transform_train = get_transform('train', args)
    train_loader = get_dataloader(get_data(transform_train, 'train', args),
                                  'train', args)
    transform_val = get_transform('val', args)
    val_loader = get_dataloader(get_data(transform_val, 'val', args), 'val',
                                args)

    transform_train_cuda = transforms.Compose([
        T.RandomHorizontalFlip(),
        T.Normalize(mean=[0.485, 0.456, 0.406],
                    std=[0.229, 0.224, 0.225],
                    channel=1)
    ])  # ImageNet
    transform_val_cuda = transforms.Compose([
        T.Normalize(mean=[0.485, 0.456, 0.406],
                    std=[0.229, 0.224, 0.225],
                    channel=1)
    ])  # ImageNet

    print('===================================')

    ### restart training ###
    if args.resume:
        if os.path.isfile(args.resume):
            checkpoint = torch.load(args.resume, map_location='cpu')
            args.start_epoch = checkpoint['epoch'] + 1
            args.iteration = checkpoint['iteration']
            best_acc = checkpoint['best_acc']
            state_dict = checkpoint['state_dict']

            try:
                model_without_dp.load_state_dict(state_dict)
            except RuntimeError:
                print('[WARNING] resuming training with different weights')
                neq_load_customized(model_without_dp, state_dict, verbose=True)
            print("=> load resumed checkpoint '{}' (epoch {})".format(
                args.resume, checkpoint['epoch']))

            try:
                optimizer.load_state_dict(checkpoint['optimizer'])
            except (RuntimeError, ValueError, KeyError):
                print(
                    '[WARNING] failed to load optimizer state, initialize optimizer'
                )
        else:
            print("[Warning] no checkpoint found at '{}', use random init".
                  format(args.resume))

    elif args.pretrain:
        if os.path.isfile(args.pretrain):
            checkpoint = torch.load(args.pretrain, map_location='cpu')
            state_dict = checkpoint['state_dict']

            new_dict = {}
            for k, v in state_dict.items():
                k = k.replace('encoder_q.0.', 'backbone.')
                new_dict[k] = v
            state_dict = new_dict

            try:
                model_without_dp.load_state_dict(state_dict)
            except RuntimeError:
                neq_load_customized(model_without_dp, state_dict, verbose=True)
            print("=> loaded pretrained checkpoint '{}' (epoch {})".format(
                args.pretrain, checkpoint['epoch']))
        else:
            print("[Warning] no checkpoint found at '{}', use random init".
                  format(args.pretrain))

    else:
        print("=> train from scratch")

    torch.backends.cudnn.benchmark = True

    # plot tools
    # use args.img_path set by set_path(args); bare img_path is undefined here
    writer_val = SummaryWriter(logdir=os.path.join(args.img_path, 'val'))
    writer_train = SummaryWriter(logdir=os.path.join(args.img_path, 'train'))
    args.val_plotter = TB.PlotterThread(writer_val)
    args.train_plotter = TB.PlotterThread(writer_train)

    args.logger = Logger(path=args.img_path)
    args.logger.log('args=\n\t\t' + '\n\t\t'.join(
        ['%s:%s' % (str(k), str(v)) for k, v in vars(args).items()]))

    # main loop
    for epoch in range(args.start_epoch, args.epochs):
        np.random.seed(epoch)
        random.seed(epoch)

        adjust_learning_rate(optimizer, epoch, args)

        train_one_epoch(train_loader, model, ce_loss, optimizer,
                        transform_train_cuda, device, epoch, args)

        if epoch % args.eval_freq == 0:
            _, val_acc = validate(val_loader, model, ce_loss,
                                  transform_val_cuda, device, epoch, args)

            # save check_point
            is_best = val_acc > best_acc
            best_acc = max(val_acc, best_acc)
            state_dict = model_without_dp.state_dict()
            save_dict = {
                'epoch': epoch,
                'state_dict': state_dict,
                'best_acc': best_acc,
                'optimizer': optimizer.state_dict(),
                'iteration': args.iteration
            }
            save_checkpoint(save_dict,
                            is_best,
                            1,
                            filename=os.path.join(args.model_path,
                                                  'epoch%d.pth.tar' % epoch),
                            keep_all=False)

    print('Training from ep %d to ep %d finished' %
          (args.start_epoch, args.epochs))
    sys.exit(0)