Example #1
def train(opt):
    seq = iaa.Sequential([
        iaa.CropToFixedSize(opt.fineSize, opt.fineSize),
    ])
    dataset_train = ImageDataset(opt.source_root_train,
                                 opt.gt_root_train,
                                 transform=seq)
    dataset_test = ImageDataset(opt.source_root_test,
                                opt.gt_root_test,
                                transform=seq)
    dataloader_train = DataLoader(dataset_train,
                                  batch_size=opt.batchSize,
                                  shuffle=True,
                                  num_workers=opt.nThreads)
    dataloader_test = DataLoader(dataset_test,
                                 batch_size=opt.batchSize,
                                 shuffle=False,
                                 num_workers=opt.nThreads)
    model = StainNet(opt.input_nc, opt.output_nc, opt.n_layer, opt.channels)
    model = nn.DataParallel(model).cuda()
    optimizer = SGD(model.parameters(), lr=opt.lr)
    loss_function = torch.nn.L1Loss()
    scheduler = lr_scheduler.CosineAnnealingLR(optimizer, opt.epoch)
    vis = Visualizer(env=opt.name)
    best_psnr = 0
    for i in range(opt.epoch):
        for j, (source_image,
                target_image) in tqdm(enumerate(dataloader_train)):
            target_image = target_image.cuda()
            source_image = source_image.cuda()
            output = model(source_image)
            loss = loss_function(output, target_image)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            if (j + 1) % opt.display_freq == 0:
                vis.plot("loss", float(loss))
                vis.img("target image", target_image[0] * 0.5 + 0.5)
                vis.img("source image", source_image[0] * 0.5 + 0.5)
                vis.img("output", (output[0] * 0.5 + 0.5).clamp(0, 1))
        if (i + 1) % 5 == 0:
            test_result = test(model, dataloader_test)
            vis.plot_many(test_result)
            if best_psnr < test_result["psnr"]:
                save_path = "{}/{}_best_psnr_layer{}_ch{}.pth".format(
                    opt.checkpoints_dir, opt.name, opt.n_layer, opt.channels)
                best_psnr = test_result["psnr"]
                torch.save(model.module.state_dict(), save_path)
                print(save_path, test_result)
        scheduler.step()
        print("lr =", scheduler.get_last_lr())
Example #2
 def train_dataloader(self):
     if self.hparams.homo_u:
         # must set trainer flag reload_dataloaders_every_epoch=True
         if self.train_dataset is None:
             self.train_dataset = HomoImageDataset(self.data_path,
                                                   self.hparams.T_pred)
         if self.current_epoch < 1000:
             # alternate between the zero-control dataset and the control datasets
             if self.current_epoch % 2 == 0:
                 u_idx = 0
             else:
                 u_idx = self.non_ctrl_ind
                 self.non_ctrl_ind += 1
                 if self.non_ctrl_ind == 9:
                     self.non_ctrl_ind = 1
         else:
             u_idx = self.current_epoch % 9
         self.train_dataset.u_idx = u_idx
         self.t_eval = torch.from_numpy(self.train_dataset.t_eval)
         return DataLoader(self.train_dataset,
                           batch_size=self.hparams.batch_size,
                           shuffle=True,
                           collate_fn=my_collate)
     else:
         train_dataset = ImageDataset(self.data_path,
                                      self.hparams.T_pred,
                                      ctrl=True)
         self.t_eval = torch.from_numpy(train_dataset.t_eval)
         return DataLoader(train_dataset,
                           batch_size=self.hparams.batch_size,
                           shuffle=True,
                           collate_fn=my_collate)
Example #3
def main():
    # Create dataset
    content_ds = ImageDataset(CONTENT_DS_PATH, batch_size=BATCH_SIZE)
    style_ds = ImageDataset(STYLE_DS_PATH, batch_size=BATCH_SIZE)

    # Build model
    vgg19 = build_vgg19(INPUT_SHAPE, VGG_PATH)  # encoder
    decoder = build_decoder(vgg19.output.shape[1:])  # input shape == encoder output shape
    model = build_model(vgg19, decoder, INPUT_SHAPE)

    #model.load_weights(SAVE_PATH)

    # Get loss
    vgg19_relus = build_vgg19_relus(vgg19)
    loss = get_loss(vgg19, vgg19_relus, epsilon=EPSILON, style_weight=STYLE_WEIGHT, color_weight=COLOR_LOSS)

    # Train model
    train(model, content_ds, style_ds, loss, n_epochs=EPOCHS, save_path=SAVE_PATH)
Example #4
 def train_dataloader(self):
     train_dataset = ImageDataset(self.data_path, self.hparams.T_pred, ctrl=False)
     # self.us = data['us']
     # self.u = self.us[self.idx]
     self.t_eval = torch.from_numpy(train_dataset.t_eval)
     return DataLoader(train_dataset, 
                       batch_size=self.hparams.batch_size, 
                       shuffle=True, 
                       collate_fn=my_collate,
                       drop_last=True,
                       num_workers=4)
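Examples #2 and #4 both pass a my_collate helper that is not shown. A common reason for a custom collate function is to drop samples the Dataset returns as None (e.g., unreadable files) before batching; a minimal sketch under that assumption:

from torch.utils.data.dataloader import default_collate

def my_collate(batch):
    # drop samples the dataset flagged as unusable, then batch as usual
    batch = [sample for sample in batch if sample is not None]
    return default_collate(batch)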
Example #5
def train():
    """Train"""
    client = storage.Client(PROJECT)
    raw_bucket = client.get_bucket(RAW_BUCKET)
    bucket = client.get_bucket(BUCKET)

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"  Device found = {device}")

    metadata_df = (
        pd.read_csv(f"gs://{RAW_BUCKET}/{RAW_DATA_DIR}/metadata.csv").query(
            "view == 'PA'")  # taking only PA view
    )

    print("Split train and validation data")
    proc_data = ImageDataset(
        root_dir=BASE_DIR,
        image_dir=PREPROCESSED_DIR,
        df=metadata_df,
        bucket=bucket,
        transform=ToTensor(),
    )
    seed_torch(seed=42)
    valid_size = int(len(proc_data) * 0.2)
    train_data, valid_data = torch.utils.data.random_split(
        proc_data, [len(proc_data) - valid_size, valid_size])
    train_loader = DataLoader(train_data,
                              batch_size=CFG.batch_size,
                              shuffle=True,
                              drop_last=True)
    valid_loader = DataLoader(valid_data,
                              batch_size=CFG.batch_size,
                              shuffle=False)

    print("Train model")
    se_model_blob = raw_bucket.blob(CFG.pretrained_weights)
    model = CustomSEResNeXt(
        BytesIO(se_model_blob.download_as_string()),
        device,
        CFG.n_classes,
        save=CFG.pretrained_model_path,
    )
    train_fn(model, train_loader, valid_loader, device)

    print("Evaluate")
    y_probs, y_val = predict(model, valid_loader, device)
    y_preds = y_probs.argmax(axis=1)

    compute_log_metrics(y_val, y_probs[:, 1], y_preds)
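The seed_torch() call above is another helper external to this snippet; it usually follows the common full-reproducibility recipe. A sketch:

import os
import random

import numpy as np
import torch

def seed_torch(seed=42):
    # seed every RNG that can influence training
    random.seed(seed)
    os.environ['PYTHONHASHSEED'] = str(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    torch.backends.cudnn.deterministic = True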
Example #6
def main(args):

    transform = Compose([
        Resize(256),
        CenterCrop(224),
        ToTensor(),
        Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
    ])

    dataset = ImageDataset(args.image_folder, transform=transform)
    n_images = len(dataset)
    dataloader = DataLoader(dataset,
                            shuffle=False,
                            batch_size=32,
                            pin_memory=True,
                            num_workers=8)

    resnet = torchvision.models.resnet50(pretrained=True).to(args.device)
    resnet.eval()  # use inference-mode batchnorm statistics for feature extraction
    features = h5py.File(args.out, 'a')  # h5py needs an explicit mode to create/write the file

    blocks = itertools.chain(resnet.layer1, resnet.layer2, resnet.layer3,
                             resnet.layer4, (resnet.avgpool, ))
    blocks = list(blocks)
    n_features = len(blocks)
    block_idx = dict(zip(blocks, map('{:02d}'.format, range(n_features))))

    n_processed = 0

    def extract(self, input, output):
        extracted = output
        if extracted.ndimension() > 2:
            extracted = F.avg_pool2d(
                extracted, extracted.shape[-2:]).squeeze(3).squeeze(2)

        block_num = block_idx[self]
        batch_size, feature_dims = extracted.shape
        dset = features.require_dataset(block_num, (n_images, feature_dims),
                                        dtype='float32')
        # extracted = extracted.to('cpu')
        dset[n_processed:n_processed + batch_size, :] = extracted.to('cpu')

    for b in blocks:
        b.register_forward_hook(extract)

    with torch.no_grad():
        for x in tqdm(dataloader):
            resnet(x.to(args.device))
            n_processed += x.shape[0]
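One follow-up on the hook registration above: register_forward_hook returns a handle, and the script never detaches its hooks or closes the HDF5 file. A variant of the final block that cleans up after itself, reusing the names defined above:

    handles = [b.register_forward_hook(extract) for b in blocks]

    with torch.no_grad():
        for x in tqdm(dataloader):
            resnet(x.to(args.device))
            n_processed += x.shape[0]

    for handle in handles:
        handle.remove()  # detach the hooks so the model can be reused safely
    features.close()  # flush the extracted features to disk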
Example #7
    def __init__(self, training_config, G, D, device):
        self.device = device
        self.config = SimpleNamespace(**training_config)
        self.exp_dir = Path("2_experiments") / self.config.exp_dir
        self.exp_dir.mkdir(parents=True, exist_ok=True)
        self.stat = {}  # training statistics
        self.logger = init_logger("training",
                                  self.exp_dir / "log")  # training logger

        dataset = ImageDataset(self.config.dataset)
        self.data_loader = DataLoader(
            dataset,
            batch_size=self.config.batch_size,
            shuffle=True,
            num_workers=self.config.num_loader_worker)

        self.G = G.to(self.device)
        self.D = D.to(self.device)
        self.logger.info("Latent Size: {}".format(self.G.get_input_shape()))
        self.optimizer_G = self.get_optimizer(self.G, self.config.optimizer_G)
        self.optimizer_D = self.get_optimizer(self.D, self.config.optimizer_D)

        self.loss = nn.BCELoss().to(self.device)
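The get_optimizer method is not included in this excerpt. A plausible implementation, assuming each optimizer entry in the config is a dict such as {"name": "adam", "lr": 2e-4, "betas": [0.5, 0.999]} (the key names are an assumption):

    def get_optimizer(self, model, optimizer_config):
        # build an optimizer from a config dict: "name" selects the class,
        # the remaining keys are passed through as keyword arguments
        cfg = dict(optimizer_config)
        name = cfg.pop("name").lower()
        if name == "adam":
            return torch.optim.Adam(model.parameters(), **cfg)
        if name == "sgd":
            return torch.optim.SGD(model.parameters(), **cfg)
        raise ValueError("Unknown optimizer: {}".format(name))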
Example #8
def main(args):
    normalize = Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
    transform = Compose([Resize(256), CenterCrop(224), ToTensor(), normalize])

    dataset = ImageDataset(args.image_folder,
                           transform=transform,
                           return_paths=True)
    # n_images = len(dataset)
    dataloader = DataLoader(dataset,
                            shuffle=False,
                            batch_size=args.batch_size,
                            pin_memory=True,
                            num_workers=0)

    model = models.resnet50(pretrained=True).to(args.device)
    model.eval()

    config = tf.ConfigProto(intra_op_parallelism_threads=1,
                            inter_op_parallelism_threads=1,
                            allow_soft_placement=True,
                            device_count={'CPU': 1})
    sess = tf.Session(config=config)
    x_op = tf.placeholder(tf.float32, shape=(
        None,
        3,
        224,
        224,
    ))

    tf_model = convert_pytorch_model_to_tf(model, args.device)
    cleverhans_model = CallableModelWrapper(tf_model, output_layer='logits')

    # compute clip_min and clip_max using a full black and a full white image
    clip_min = normalize(torch.zeros(3, 1, 1)).min().item()
    clip_max = normalize(torch.ones(3, 1, 1)).max().item()

    eps = args.eps / 255.
    eps_iter = 20
    nb_iter = 10
    args.ord = np.inf if args.ord < 0 else args.ord
    grad_params = {'eps': eps, 'ord': args.ord}
    common_params = {'clip_min': clip_min, 'clip_max': clip_max}
    iter_params = {'eps_iter': eps_iter / 255., 'nb_iter': nb_iter}

    attack_name = ''
    if args.attack == 'fgsm':
        attack_name = '_L{}_eps{}'.format(args.ord, args.eps)
        attack_op = FastGradientMethod(cleverhans_model, sess=sess)
        attack_params = {**common_params, **grad_params}
    elif args.attack == 'iter':
        attack_name = '_L{}_eps{}_epsi{}_i{}'.format(args.ord, args.eps,
                                                     eps_iter, nb_iter)
        attack_op = BasicIterativeMethod(cleverhans_model, sess=sess)
        attack_params = {**common_params, **grad_params, **iter_params}
    elif args.attack == 'm-iter':
        attack_name = '_L{}_eps{}_epsi{}_i{}'.format(args.ord, args.eps,
                                                     eps_iter, nb_iter)
        attack_op = MomentumIterativeMethod(cleverhans_model, sess=sess)
        attack_params = {**common_params, **grad_params, **iter_params}
    elif args.attack == 'pgd':
        attack_name = '_L{}_eps{}_epsi{}_i{}'.format(args.ord, args.eps,
                                                     eps_iter, nb_iter)
        attack_op = MadryEtAl(cleverhans_model, sess=sess)
        attack_params = {**common_params, **grad_params, **iter_params}
    elif args.attack == 'jsma':
        attack_op = SaliencyMapMethod(cleverhans_model, sess=sess)
        attack_params = {'theta': eps, 'symbolic_impl': False, **common_params}
    elif args.attack == 'deepfool':
        attack_op = DeepFool(cleverhans_model, sess=sess)
        attack_params = common_params
    elif args.attack == 'cw':
        attack_op = CarliniWagnerL2(cleverhans_model, sess=sess)
        attack_params = common_params
    elif args.attack == 'lbfgs':
        attack_op = LBFGS(cleverhans_model, sess=sess)
        target = np.zeros((1, 1000))
        target[0, np.random.randint(1000)] = 1
        y = tf.placeholder(tf.float32, target.shape)
        attack_params = {'y_target': y, **common_params}

    attack_name = args.attack + attack_name

    print('Running [{}]. Params: {}'.format(args.attack.upper(),
                                            attack_params))

    adv_x_op = attack_op.generate(x_op, **attack_params)
    adv_preds_op = tf_model(adv_x_op)
    preds_op = tf_model(x_op)

    n_success = 0
    n_processed = 0
    progress = tqdm(dataloader)
    for paths, x in progress:

        progress.set_description('ATTACK')

        # the y placeholder and target only exist for the LBFGS attack,
        # so build the feed dict conditionally
        feed_dict = {x_op: x}
        if args.attack == 'lbfgs':
            feed_dict[y] = target
        z, adv_x, adv_z = sess.run([preds_op, adv_x_op, adv_preds_op],
                                   feed_dict=feed_dict)

        src, dst = np.argmax(z, axis=1), np.argmax(adv_z, axis=1)
        success = src != dst
        success_paths = np.array(paths)[success]
        success_adv_x = adv_x[success]
        success_src = src[success]
        success_dst = dst[success]

        n_success += success_adv_x.shape[0]
        n_processed += x.shape[0]

        progress.set_postfix(
            {'Success': '{:3.2%}'.format(n_success / n_processed)})
        progress.set_description('SAVING')

        for p, a, s, d in zip(success_paths, success_adv_x, success_src,
                              success_dst):
            path = '{}_{}_src{}_dst{}.npz'.format(p, attack_name, s, d)
            path = os.path.join(args.out_folder, path)
            np.savez_compressed(path, img=a)
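The arrays saved above are still in normalized (mean/std) space. To inspect one visually, the Normalize transform has to be inverted; a short sketch (the filename is hypothetical):

import numpy as np

data = np.load('cat.jpg_fgsm_Linf_eps8_src281_dst285.npz')  # hypothetical file
adv = data['img']  # CHW, normalized with the ImageNet mean/std used above
mean = np.array([0.485, 0.456, 0.406]).reshape(3, 1, 1)
std = np.array([0.229, 0.224, 0.225]).reshape(3, 1, 1)
img = np.clip(adv * std + mean, 0.0, 1.0).transpose(1, 2, 0)  # HWC in [0, 1]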
Example #9
def train():
    dataset = ImageDataset('./dataset/train', args)
    trainDataLoader = torch.utils.data.DataLoader(
        dataset=dataset,
        batch_size=args['batchsize'],
        shuffle=True,
        num_workers=args['num_works'],
        pin_memory=True,
    )
    val = ValImageDataset('./dataset/val', args)
    valDataLoader = torch.utils.data.DataLoader(
        dataset=val,
        batch_size=1,
        shuffle=False,
        num_workers=1,
        pin_memory=True,
    )
    model = DCNN(16).to(device)
    optimizer = torch.optim.SGD(model.parameters(),
                                lr=args['lr'],
                                weight_decay=1e-4,
                                momentum=0.9)
    criterion = MseLoss().to(device)
    best_loss = np.inf
    best_psnr = -1
    for epoch in range(args['epochs']):
        lr = adjust_learning_rate(optimizer, epoch)
        logger.info('Epoch<%d/%d> current lr:%f', epoch + 1, args['epochs'],
                    lr)
        model.train()
        train_loss = []
        for i, (image, gt, _) in enumerate(trainDataLoader):
            image, gt = image.to(device), gt.to(device)
            pre = model(image)
            pre = torch.tanh(pre)
            loss = criterion(pre, gt)
            train_loss.append(loss.item())
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            logger.info('Epoch<%d/%d>|Step<%d/%d> avg loss:%.6f' %
                        (epoch + 1, args['epochs'], i + 1,
                         len(trainDataLoader), np.mean(train_loss)))
        model.eval()
        val_loss = []
        PSNR = []
        with torch.no_grad():
            for j, (image, gt, _) in tqdm(enumerate(valDataLoader),
                                          total=len(valDataLoader)):
                image, gt = image.to(device), gt.to(device)
                pre = model(image)
                pre = torch.tanh(pre)
                loss = criterion(pre, gt)
                val_loss.append(loss.item())
                PSNR.append(
                    criCalc.caclBatchPSNR((pre + 1) * 127.5, (gt + 1) * 127.5))
            if (np.mean(val_loss) <= best_loss):
                best_loss = np.mean(val_loss)
                model_related_loss = {
                    'model': model.state_dict(),
                    'epoch': epoch + 1,
                    'best_loss': best_loss,
                    'best_psnr': -1
                }
                torch.save(model_related_loss,
                           args['model'] + '/IRS_best_loss.pkl')
            if (np.mean(PSNR) >= best_psnr):
                best_psnr = np.mean(PSNR)
                model_related_psnr = {
                    'model': model.state_dict(),
                    'epoch': epoch + 1,
                    'best_loss': -1,
                    'best_psnr': best_psnr
                }
                torch.save(model_related_psnr,
                           args['model'] + '/IRS_best_psnr.pkl')
            logger.info(
                'Epoch<%d/%d> current loss:%.6f, best loss:%.6f, current PSNR:%f(max:%f|min:%f), best PSNR:%f'
                % (epoch + 1, args['epochs'], np.mean(val_loss), best_loss,
                   np.mean(PSNR), np.max(PSNR), np.min(PSNR), best_psnr))
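criCalc.caclBatchPSNR is external to this snippet; given that both call sites rescale predictions and targets to [0, 255], a plausible implementation is:

def caclBatchPSNR(pred, gt):
    # per-image MSE over channels and pixels, PSNR averaged over the batch;
    # assumes 4D tensors scaled to [0, 255]
    mse = torch.mean((pred - gt) ** 2, dim=(1, 2, 3))
    psnr = 10.0 * torch.log10(255.0 ** 2 / mse)
    return psnr.mean().item()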
Example #10
def main():
    start_time = datetime.now()

    args = parse_args()
    print(args)

    # Compute all of the features of test set using pretrained model.
    print('Loading Model')
    device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')

    # Model 1
    embed_1, encoder_1 = model_initialization(
        model_name='efficientnet-b3',
        ckpt_path='./all_ckpts/efficientnet-b3_multimask4.pkl',
        device=device)
    embed_2, encoder_2 = model_initialization(
        model_name='efficientnet-b2',
        ckpt_path='./all_ckpts/efficientnet-b2_multimask4.pkl',
        device=device)
    embed_3, encoder_3 = model_initialization(
        model_name='resnet152',
        ckpt_path='./all_ckpts/resnet152_multimask2.pkl',
        device=device)
    embed_4, encoder_4 = model_initialization(
        model_name='resnet101',
        ckpt_path='./all_ckpts/resnet101_multimask3.pkl',
        device=device)
    print('Done')

    # Load testset
    print('Loading Testset...')
    SIZE = 224
    BATCH_SIZE = 4
    data_transforms = transforms.Compose([
        transforms.Resize((SIZE, SIZE)),
        transforms.ToTensor(),
        transforms.Normalize(mean=(0.485, 0.456, 0.406),
                             std=(0.229, 0.224, 0.225))
    ])

    def collate_fn(data):
        img_names, imgs = list(zip(*data))
        imgs = torch.stack(imgs)
        return img_names, imgs

    dataset = ImageDataset(dataset_path=Path(args.test_path),
                           istrain=False,
                           transforms=data_transforms)
    dataloader = torch.utils.data.DataLoader(dataset,
                                             batch_size=BATCH_SIZE,
                                             shuffle=False,
                                             num_workers=4,
                                             pin_memory=True,
                                             collate_fn=collate_fn)
    print('Done')

    # Acquire testset features.
    print('Getting testset features...')
    list_feat_vector_df = []
    with tqdm(total=len(dataset)) as pbar:
        for i, (img_names, imgs) in enumerate(dataloader):
            pbar.update(BATCH_SIZE)

            with torch.no_grad():
                z1 = get_features(embed_1, encoder_1, device, imgs)
                z2 = get_features(embed_2, encoder_2, device, imgs)
                z3 = get_features(embed_3, encoder_3, device, imgs)
                z4 = get_features(embed_4, encoder_4, device, imgs)

                z = np.concatenate([z1, z2, z3, z4], axis=1)

                gc.disable()  # Disable the garbage collection
                for j, (img_name, z_) in enumerate(zip(img_names, z)):
                    list_feat_vector_df.append(
                        pd.DataFrame({
                            'img_name': img_name,
                            'z': [z_]
                        }))
                gc.enable()
    feat_vector_df = pd.concat(list_feat_vector_df)
    img_name_query_np = feat_vector_df.img_name.values

    feat_query = np.stack(list(feat_vector_df.z.values), axis=0)  # [num, c]
    feat_query = (feat_query - feat_query.mean(
        0, keepdims=True)) / feat_query.std(0, keepdims=True)

    # Load Database Features
    print('Load database features...')
    img_name_db_np, z1_np = get_database_features(
        feature_path='./feature/efficientnet-b3_multimask_last.hdf5',
        model_name='efficientnet-b3')
    img_name_db_np, z2_np = get_database_features(
        feature_path='./feature/efficientnet-b2_multimask_last.hdf5',
        model_name='efficientnet-b2')
    img_name_db_np, z3_np = get_database_features(
        feature_path='./feature/resnet152_multimask_last.hdf5',
        model_name='resnet152')
    img_name_db_np, z4_np = get_database_features(
        feature_path='./feature/resnet101_multimask_last.hdf5',
        model_name='resnet101')

    feat_keys = np.concatenate([z1_np, z2_np, z3_np, z4_np],
                               axis=1)  # [num, c]
    feat_keys = (feat_keys - feat_keys.mean(0, keepdims=True)) / feat_keys.std(
        0, keepdims=True)

    # QE
    # print('Query Expansion')
    # QE_num = 2
    # index_flat = faiss.IndexFlatL2(feat_keys.shape[1])
    # index_flat.add(feat_keys)
    # D, I = index_flat.search(feat_query, QE_num - 1)
    # new_feat_query=copy.deepcopy(feat_query)
    # for num in range(len(new_feat_query)):
    #     new_feat_query[num] = (new_feat_query[num] + feat_keys[I[num][0]]) / float(QE_num)
    # print('Done')

    # Image Retrieval
    TOP_K = 7
    print('Image retrieval...')
    index = faiss.IndexFlatL2(feat_keys.shape[1])
    index.add(feat_keys)
    D, I = index.search(feat_query, TOP_K)

    print("Calculate top-7 query results.")
    # Save query results
    query_res = []
    for i, topk_idx in enumerate(I):
        img_query_id = img_name_query_np[i].split('.')[0]

        query_sample_res = [img_query_id]
        for j, jth_idx in enumerate(topk_idx):
            img_key_id = img_name_db_np[jth_idx].split('.')[0]
            query_sample_res.append(img_key_id)
        query_res.append(query_sample_res)

    query_res = pd.DataFrame(query_res)
    query_res.rename(columns={0: 'Query'}, inplace=True)
    query_res.to_csv(args.output_result_path, index=False, header=False)
    print('Done. Total seconds:{:}s'.format(
        (datetime.now() - start_time).total_seconds()))
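The IndexFlatL2 search above is exact Euclidean retrieval over the standardized features. If cosine similarity were preferred, the same step could L2-normalize both sides and use an inner-product index instead; a sketch, assuming feat_keys and feat_query are C-contiguous float32 arrays:

import faiss

faiss.normalize_L2(feat_keys)   # in-place row-wise L2 normalization
faiss.normalize_L2(feat_query)
index = faiss.IndexFlatIP(feat_keys.shape[1])  # inner product == cosine after normalization
index.add(feat_keys)
D, I = index.search(feat_query, TOP_K)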
Example #11
    parser.add_argument('--decay',
                        default=100,
                        type=int,
                        help='Epoch to start linearly decaying lr to 0')
    parser.add_argument('--save_root',
                        default='result',
                        type=str,
                        help='Result saved root path')

    # args parse
    args = parser.parse_args()
    data_root, batch_size, epochs, lr = args.data_root, args.batch_size, args.epochs, args.lr
    decay, save_root = args.decay, args.save_root

    # data prepare
    train_data = ImageDataset(data_root, 'train')
    test_data = ImageDataset(data_root, 'test')
    train_loader = DataLoader(train_data,
                              batch_size=batch_size,
                              shuffle=True,
                              num_workers=8)
    test_loader = DataLoader(test_data,
                             batch_size=1,
                             shuffle=False,
                             num_workers=8)

    # model setup
    G_A = Generator(3, 3).cuda()
    G_B = Generator(3, 3).cuda()
    D_A = Discriminator(3).cuda()
    D_B = Discriminator(3).cuda()
Example #12
#Test mode
netG_A2B.eval()
netG_B2A.eval()

#Inputs and targets memory allocation
Tensor = torch.cuda.FloatTensor if opt.cuda else torch.Tensor
input_A = Tensor(opt.batchSize, opt.input_nc, opt.size, opt.size)
input_B = Tensor(opt.batchSize, opt.output_nc, opt.size, opt.size)

#Load data
transform = [
    transforms.ToTensor(),
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
]
dataloader = DataLoader(ImageDataset(opt.dataroot,
                                     transform=transform,
                                     mode='test'),
                        batch_size=opt.batchSize,
                        shuffle=False,
                        num_workers=opt.n_cpu)
####################

#######Testing#######

#Output directory
if not os.path.exists('output/A'):
    os.makedirs('output/A')
if not os.path.exists('output/B'):
    os.makedirs('output/B')

for i, batch in enumerate(dataloader):
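The loop body is truncated in this excerpt. In the reference PyTorch-CycleGAN test script it copies each batch into the pre-allocated input tensors, runs both generators, and writes the translated images out, roughly as follows (a sketch, assuming save_image from torchvision.utils is imported):

    real_A = Variable(input_A.copy_(batch['A']))
    real_B = Variable(input_B.copy_(batch['B']))
    # rescale generator outputs from [-1, 1] to [0, 1] before saving
    fake_B = 0.5 * (netG_A2B(real_A).data + 1.0)
    fake_A = 0.5 * (netG_B2A(real_B).data + 1.0)
    save_image(fake_A, 'output/A/%04d.png' % (i + 1))
    save_image(fake_B, 'output/B/%04d.png' % (i + 1))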
Example #13
def main():
    args = parse_args()
    print(args)

    # Compute all of the features of test set using pretrained model.
    print('Loading Model')
    device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')

    model_name = args.model_name
    if model_name.startswith('efficientnet'):
        embed = MultiHeadEffNet(model=model_name, pretrained=False)
    elif model_name.startswith('resnet'):
        embed = MultiHeadResNet(model=model_name, pretrained=False)

    mask_encoder = MultiHeadMaskEncoder(model=embed.model, name=model_name)
    if args.model_ckpt:
        checkpoint = torch.load(args.model_ckpt, map_location='cpu')
        embed.load_state_dict(checkpoint['embed'])
        mask_encoder.load_state_dict(checkpoint['mask_encoder'])
        print('Load state dict.')

    embed.to(device)
    embed.eval()
    mask_encoder.to(device)
    mask_encoder.eval()
    print('Done')

    # Load testset
    print('Loading Testset...')
    SIZE = 224
    data_transforms = transforms.Compose([
        transforms.Resize((SIZE, SIZE)),
        transforms.ToTensor(),
        transforms.Normalize(mean=(0.485, 0.456, 0.406),
                             std=(0.229, 0.224, 0.225))
    ])

    def collate_fn(data):
        img_names, imgs = list(zip(*data))
        imgs = torch.stack(imgs)
        return img_names, imgs

    dataset = ImageDataset(dataset_path=Path(args.test_path),
                           istrain=False,
                           transforms=data_transforms)
    BATCH_SIZE = 4
    dataloader = torch.utils.data.DataLoader(dataset,
                                             batch_size=BATCH_SIZE,
                                             shuffle=False,
                                             num_workers=4,
                                             pin_memory=True,
                                             collate_fn=collate_fn)
    print('Done')

    # Acquire testset features.
    print('Getting testset features...')
    # Acquire the database features.

    if model_name.startswith('efficientnet'):
        valid_layer = [2, 3]
        img_name_np = []
        vec_length = np.array(embed.out_channels)[valid_layer].sum()
        z_att_np = np.zeros([len(dataset), vec_length], dtype='float32')
        z_att_max_np = np.zeros([len(dataset), vec_length], dtype='float32')
        z_max_np = np.zeros([len(dataset), embed.last_channels],
                            dtype='float32')

        with tqdm(total=len(dataset)) as pbar:
            for i, (img_names, imgs) in enumerate(dataloader):
                pbar.update(BATCH_SIZE)

                with torch.no_grad():
                    img_name_np.extend(img_names)
                    xs = embed(
                        imgs.to(device))  # x_representation:[b, 2048, h, w]
                    masks = mask_encoder(xs)  # [b, 1, h, w]
                    z_ATT = np.concatenate([
                        norm(
                            mask_encoder.attention_pooling(x, mask).squeeze(
                                3).squeeze(2).detach().cpu().numpy())
                        for i, (x, mask) in enumerate(zip(xs, masks))
                        if i in valid_layer
                    ],
                                           axis=1)
                    z_ATTMAX = np.concatenate([
                        norm(
                            F.adaptive_max_pool2d(x * mask, output_size=1).
                            squeeze(3).squeeze(2).detach().cpu().numpy())
                        for i, (x, mask) in enumerate(zip(xs, masks))
                        if i in valid_layer
                    ],
                                              axis=1)
                    z_MAX = norm(
                        F.adaptive_max_pool2d(xs[-1], output_size=1).squeeze(
                            3).squeeze(2).detach().cpu().numpy())

                for idx, (z_ATT_, z_ATTMAX_,
                          z_MAX_) in enumerate(zip(z_ATT, z_ATTMAX, z_MAX)):
                    z_att_np[i * BATCH_SIZE + idx, :] = z_ATT_
                    z_att_max_np[i * BATCH_SIZE + idx, :] = z_ATTMAX_
                    z_max_np[i * BATCH_SIZE + idx, :] = z_MAX_

        # Save the features
        img_name_query_np = np.array(img_name_np, dtype='object')
        z_np = np.concatenate([z_att_np, z_att_max_np, z_max_np], axis=1)
        feat_query = (z_np - z_np.mean(0, keepdims=True)) / z_np.std(
            0, keepdims=True)
        print('Done')

        # Load Database Features
        print('Load database features...')
        f = h5py.File(args.feature_path, 'r')
        img_name_db_np, z_att_np, z_att_max_np, z_max_np = \
            f['img_name_ds'][:], f['z_att_ds'][:, -vec_length:], f['z_att_max_ds'][:, -vec_length:], f['z_max_ds'][:]

        z_np = np.concatenate([z_att_np, z_att_max_np, z_max_np], axis=1)

        feat_keys = (z_np - z_np.mean(0, keepdims=True)) / z_np.std(
            0, keepdims=True)
        print('Done')

    elif model_name.startswith('resnet'):
        valid_layer = [3]
        img_name_np = []
        vec_length = np.array(embed.out_channels)[valid_layer].sum()
        z_att_np = np.zeros([len(dataset), vec_length], dtype='float32')
        # z_att_np = np.zeros([len(dataset), embed.out_channels[-1]], dtype='float32')

        with tqdm(total=len(dataset)) as pbar:
            for i, (img_names, imgs) in enumerate(dataloader):
                pbar.update(BATCH_SIZE)

                with torch.no_grad():
                    img_name_np.extend(img_names)
                    xs = embed(
                        imgs.to(device))  # x_representation:[b, 2048, h, w]
                    masks = mask_encoder(xs)  # [b, 1, h, w]
                    z_ATT = np.concatenate([
                        norm(
                            mask_encoder.attention_pooling(x, mask).squeeze(
                                3).squeeze(2).detach().cpu().numpy())
                        for i, (x, mask) in enumerate(zip(xs, masks))
                        if i in valid_layer
                    ],
                                           axis=1)
                    # z_ATT = norm(mask_encoder.attention_pooling(xs[-1], masks[-1]).squeeze(3).squeeze(2).detach().cpu().numpy())

                for idx, (z_ATT_) in enumerate(z_ATT):
                    z_att_np[i * BATCH_SIZE + idx, :] = z_ATT_

        # Save the features
        img_name_query_np = np.array(img_name_np, dtype='object')
        z_np = z_att_np
        feat_query = (z_np - z_np.mean(0, keepdims=True)) / z_np.std(
            0, keepdims=True)
        print('Done')

        # Load Database Features
        print('Load database features...')
        f = h5py.File(args.feature_path, 'r')
        img_name_db_np, z_att_np = f['img_name_ds'][:], f[
            'z_att_ds'][:, -vec_length:]
        # img_name_db_np, z_att_ds = f['img_name_ds'][:], f['z_att_ds'][:, -embed.out_channels[-1]:]
        z_np = z_att_np
        feat_keys = (z_np - z_np.mean(0, keepdims=True)) / z_np.std(
            0, keepdims=True)
        print('Done')

    # DBA
    # print('DBA')
    # DBA_num = 2
    # index_flat = faiss.IndexFlatL2(feat_keys.shape[1])
    # index_flat.add(feat_keys)
    # D, I = index_flat.search(feat_keys, DBA_num)
    #
    # new_feat_keys = copy.deepcopy(feat_keys)
    # for num in range(len(I)):
    #     new_feat = feat_keys[I[num][0]]
    #     for num1 in range(1, len(I[num])):
    #         weight = (len(I[num]) - num1) / float(len(I[num]))
    #         new_feat += feat_keys[num1] * weight
    #     new_feat_keys[num] = new_feat
    # print('Done')

    #QE
    # print('Query Expansion')
    # QE_num = 2
    # index_flat = faiss.IndexFlatL2(feat_keys.shape[1])
    # index_flat.add(feat_keys)
    # D, I = index_flat.search(feat_query, QE_num - 1)
    # new_feat_query=copy.deepcopy(feat_query)
    # for num in range(len(new_feat_query)):
    #     new_feat_query[num] = (new_feat_query[num] + feat_keys[I[num][0]]) / float(QE_num)
    # print('Done')

    # Image Retrieval
    print('Image retrieval...l2')
    index = faiss.IndexFlatL2(feat_keys.shape[1])
    index.add(feat_keys)

    TOP_K = 7
    D, I = index.search(feat_query, TOP_K)

    # Save query results
    query_res = []
    for i, topk_idx in enumerate(I):
        img_query_id = img_name_query_np[i].split('.')[0]

        query_sample_res = [img_query_id]
        for j, jth_idx in enumerate(topk_idx):
            img_key_id = img_name_db_np[jth_idx].split('.')[0]
            query_sample_res.append(img_key_id)
        query_res.append(query_sample_res)

    query_res = pd.DataFrame(query_res)
    query_res.rename(columns={0: 'Query'}, inplace=True)
    query_res.to_csv(args.output_result_path, index=False, header=False)
    print('Done')
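The norm() helper applied to each pooled feature block is not included; assuming it performs row-wise L2 normalization, it would look like:

def norm(x, eps=1e-12):
    # L2-normalize each row (feature vector) of a [batch, dim] array
    return x / (np.linalg.norm(x, axis=1, keepdims=True) + eps)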
Example #14
    def train(self, generator, discriminator, train_data_dir):
        
        # convert device
        generator.to(self.device)
        discriminator.to(self.device)

        # create a shadow copy of the generator
        generator_shadow = copy.deepcopy(generator)

        # initialize the gen_shadow weights equal to the
        # weights of gen
        update_average(generator_shadow, generator, beta=0)
        generator_shadow.train()
        
        optimizer_G = Adam(generator.parameters(),
                            lr=self.learning_rate,
                            betas=self.betas)
        optimizer_D = Adam(discriminator.parameters(),
                            lr=self.learning_rate, 
                            betas=self.betas)

        image_size = 2 ** (self.depth + 1)
        print("Construct dataset")
        train_dataset = ImageDataset(train_data_dir, image_size)
        print("Construct optimizers")

        now = datetime.datetime.now()
        checkpoint_dir = os.path.join(self.checkpoint_root, f"{now.strftime('%Y%m%d_%H%M')}-progan")
        sample_dir = os.path.join(checkpoint_dir, "sample")
        os.makedirs(checkpoint_dir, exist_ok=True)
        os.makedirs(sample_dir, exist_ok=True)

        # training loop
        loss = losses.WGAN_GP(discriminator)
        print("Training Starts.")
        start_time = time.time()
        
        logger = Logger()
        iterations = 0
        fixed_input = torch.randn(16, self.latent_dim).to(self.device)
        fixed_input /= torch.sqrt(torch.sum(fixed_input*fixed_input + 1e-8, dim=1, keepdim=True))

        for current_depth in range(self.depth):
            print(f"Currently working on Depth: {current_depth}")
            current_res = np.power(2, current_depth + 2)
            print(f"Current resolution: {current_res} x {current_res}")

            train_dataloader = DataLoader(train_dataset, batch_size=self.batch_sizes[current_depth],
                                        shuffle=True, drop_last=True, num_workers=4)
            
            ticker = 1
            
            for epoch in range(self.num_epochs[current_depth]):
                print(f"epoch {epoch}")
                total_batches = len(train_dataloader)
                fader_point = int((self.fade_ins[current_depth] / 100)
                                * self.num_epochs[current_depth] * total_batches)
                

                for i, batch in enumerate(train_dataloader):
                    alpha = ticker / fader_point if ticker <= fader_point else 1

                    images = batch.to(self.device)

                    gan_input = torch.randn(images.shape[0], self.latent_dim).to(self.device)
                    gan_input /= torch.sqrt(torch.sum(gan_input*gan_input + 1e-8, dim=1, keepdim=True))

                    real_samples = progressive_downsampling(images, current_depth, self.depth, alpha)
                        # generate a batch of samples
                    fake_samples = generator(gan_input, current_depth, alpha).detach()
                    d_loss = loss.dis_loss(real_samples, fake_samples, current_depth, alpha)

                    # optimize discriminator
                    optimizer_D.zero_grad()
                    d_loss.backward()
                    optimizer_D.step()

                    d_loss_val = d_loss.item()

                    gan_input = torch.randn(images.shape[0], self.latent_dim).to(self.device)
                    gan_input /= torch.sqrt(torch.sum(gan_input*gan_input + 1e-8, dim=1, keepdim=True))
                    fake_samples = generator(gan_input, current_depth, alpha)
                    g_loss = loss.gen_loss(real_samples, fake_samples, current_depth, alpha)
                    
                    optimizer_G.zero_grad()
                    g_loss.backward()
                    optimizer_G.step()

                    update_average(generator_shadow, generator, beta=0.999)
                    
                    g_loss_val = g_loss.item()

                    elapsed = time.time() - start_time

                    logger.log(depth=current_depth,
                               epoch=epoch,
                               i=i,
                               iterations=iterations,
                               loss_G=g_loss_val,
                               loss_D=d_loss_val,
                               time=elapsed)
                    
                    ticker += 1
                    iterations += 1

                logger.output_log(f"{checkpoint_dir}/log.csv")
                
                generator.eval()
                with torch.no_grad():
                    sample_images = generator(fixed_input, current_depth, alpha)
                save_samples(sample_images, current_depth, sample_dir, current_depth, epoch, image_size)
                generator.train()
                # range() stops at self.depth - 1, so compare against the final depth
                if current_depth == self.depth - 1:
                    torch.save(generator.state_dict(), f"{checkpoint_dir}/{current_depth}_{epoch}_generator.pth")
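update_average maintains the exponential moving average for the shadow generator: beta=0 copies the weights outright at initialization, beta=0.999 tracks them slowly during training. A minimal sketch, assuming both models iterate parameters in the same order:

def update_average(model_tgt, model_src, beta):
    with torch.no_grad():
        for p_tgt, p_src in zip(model_tgt.parameters(), model_src.parameters()):
            # EMA update: target <- beta * target + (1 - beta) * source
            p_tgt.copy_(beta * p_tgt + (1.0 - beta) * p_src)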
Example #15
    # Check for required input
    option_, _ = parser.parse_known_args()
    print(option_)
    if int(option_.R0) + int(option_.R20) + int(option_.Final) == 0:
        assert False, 'Please activate one of the [R0, R20, Final] options using --[R0]'
    elif int(option_.R0) + int(option_.R20) + int(option_.Final) > 1:
        assert False, 'Please activate only ONE of the [R0, R20, Final] options'

    if option_.depthNet == 1:
        from structuredrl.models import DepthNet

    # Set each network's receptive field and set the patch estimation resolution to twice the
    # receptive field size to speed up the local refinement, as described in Section 6 of the main paper.
    if option_.depthNet == 0:
        option_.net_receptive_field_size = 384
    elif option_.depthNet in (1, 2):
        option_.net_receptive_field_size = 448
    else:
        assert False, 'depthNet can only be 0, 1 or 2'
    option_.patch_netsize = 2 * option_.net_receptive_field_size

    # Create dataset from input images
    dataset_ = ImageDataset(option_.data_dir, 'test')

    # Run pipeline
    run(dataset_, option_)
Example #16
def main():
    parser = argparse.ArgumentParser(description='Image Classification.')
    parser.add_argument('--model-name', type=str, default='resnet50')
    parser.add_argument(
        '--checkpoint_path',
        type=str,
        default='/home/ubuntu/hxh/tl-ssl/tl/tl_finetune',
        help=
        'Path to save checkpoints; only the model with the highest top-1 acc will be saved, '
        'and the records will also be written in the folder')
    parser.add_argument('--batch-size',
                        type=int,
                        default=128,
                        help='Batch size')
    parser.add_argument('--lr',
                        type=float,
                        default=1e-2,
                        help='Initial learning rate')
    parser.add_argument('--epoch',
                        type=int,
                        default=100,
                        help='Maximum training epoch')
    parser.add_argument('--start-epoch',
                        type=int,
                        default=0,
                        help='Start training epoch')

    parser.add_argument('--root-dir',
                        type=str,
                        default='../data/Caltech256',
                        help='path to the image folder')
    parser.add_argument('--train-file',
                        type=str,
                        default='../dataset/caltech256_train_0.csv',
                        help='path to the train csv file')
    parser.add_argument('--val-file',
                        type=str,
                        default='../dataset/caltech256_val_0.csv',
                        help='path to the val csv file')
    parser.add_argument('--test-file',
                        type=str,
                        default='../dataset/caltech256_test_0.csv',
                        help='path to the test csv file')
    # parser.add_argument('--test-dir', type=str, default='xxx/test',
    #                     help='path to the train folder, each class has a single folder')
    parser.add_argument('--cos',
                        type=bool,
                        default=False,
                        help='Use cosine learning rate scheduler')
    parser.add_argument('--schedule',
                        default=[20, 40],
                        nargs='*',
                        type=int,
                        help='learning rate schedule (when to drop lr by 10x)')

    parser.add_argument(
        '--pretrained',
        type=str,
        default='None',
        help='Load which pretrained model: '
        'None: do not load any weights (random initialization); '
        'Imagenet: official ImageNet pretrained model; '
        'MoCo: transfer model from MoCo, path in $transfer-resume$; '
        'Transfer: transfer model from supervised pretraining, path in $transfer-resume$; '
        'Resume: load a checkpoint from an interrupted training run, path in $resume$'
    )
    parser.add_argument(
        '--transfer-resume',
        type=str,
        default='/home/ubuntu/hxh/tl-ssl/tl/tl_pretrain/best.pth.tar',
        help='Path to load transfering pretrained model')
    parser.add_argument(
        '--resume',
        type=str,
        default='/home/ubuntu/hxh/tl-ssl/tl/tl_finetune/best.pth.tar',
        help='Path to resume a checkpoint')
    parser.add_argument('--num-class',
                        type=int,
                        default=257,
                        help='Number of classes for the classification')
    parser.add_argument('--PRINT-INTERVAL',
                        type=int,
                        default=20,
                        help='Number of batch to print the loss')
    args = parser.parse_args()

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print("Device {}".format(device))
    # Create checkpoint file

    os.makedirs(args.checkpoint_path, exist_ok=True)

    normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                     std=[0.229, 0.224, 0.225])
    test_trans = transforms.Compose([
        transforms.Resize(256),
        transforms.CenterCrop(224),
        transforms.ToTensor(), normalize
    ])
    Caltech = Caltech256(root=args.root_dir,
                         train_file=args.train_file,
                         val_file=args.val_file,
                         test_file=args.test_file,
                         download=True)

    trainset = ImageDataset(images=Caltech.train_images,
                            labels=Caltech.train_label,
                            transforms=transforms.Compose([
                                transforms.Resize(256),
                                transforms.RandomResizedCrop(224),
                                transforms.RandomHorizontalFlip(),
                                auto.ImageNetPolicy(),
                                transforms.ToTensor(), normalize
                            ]))
    valset = ImageDataset(images=Caltech.val_images,
                          labels=Caltech.val_label,
                          transforms=test_trans)

    testset = ImageDataset(images=Caltech.test_images,
                           labels=Caltech.test_label,
                           transforms=test_trans)

    train_loader = DataLoader(trainset,
                              batch_size=args.batch_size,
                              num_workers=8,
                              shuffle=True,
                              pin_memory=True)
    val_loader = DataLoader(valset,
                            batch_size=args.batch_size,
                            num_workers=8,
                            pin_memory=True)
    test_loader = DataLoader(testset,
                             batch_size=args.batch_size,
                             num_workers=8,
                             pin_memory=True)

    # Define Loss Function
    LOSS_FUNC = nn.CrossEntropyLoss().to(device)

    print(args.model_name)

    if args.pretrained == 'Imagenet':
        # ImageNet supervised pretrained model
        print('ImageNet supervised pretrained model')
        model = MODEL_DICT[args.model_name](num_classes=args.num_class,
                                            pretrained=True)

    elif args.pretrained == 'MoCo':
        # load weight from transfering model from moco
        print('Load weight from transfering model from moco')
        model = MODEL_DICT[args.model_name](num_classes=args.num_class,
                                            pretrained=False)

        if args.transfer_resume:
            if os.path.isfile(args.transfer_resume):
                print("=> loading checkpoint '{}'".format(
                    args.transfer_resume))
                checkpoint = torch.load(args.transfer_resume,
                                        map_location="cpu")

                # rename moco pre-trained keys
                state_dict = checkpoint['state_dict']
                for k in list(state_dict.keys()):
                    # retain only encoder_q up to before the embedding layer
                    if k.startswith('module.encoder_q') and not k.startswith(
                            'module.encoder_q.fc'):
                        # remove prefix
                        state_dict[
                            k[len("module.encoder_q."):]] = state_dict[k]
                    # delete renamed or unused k
                    del state_dict[k]

                msg = model.load_state_dict(state_dict, strict=False)
                assert set(msg.missing_keys) == {"fc.weight", "fc.bias"}

                print("=> loaded pre-trained model '{}'".format(
                    args.transfer_resume))
            else:
                print("=> no checkpoint found at '{}'".format(
                    args.transfer_resume))

        # init the fc layer
        model.fc.weight.data.normal_(mean=0.0, std=0.01)
        model.fc.bias.data.zero_()

    elif args.pretrained == 'Transfer':
        # load weight from transfering model from supervised pretraining

        model = MODEL_DICT[args.model_name](num_classes=args.num_class,
                                            pretrained=False)
        print('Load weight from transfering model from supervised pretraining')

        if args.transfer_resume:
            if os.path.isfile(args.transfer_resume):
                print("=> loading checkpoint '{}'".format(
                    args.transfer_resume))

                checkpoint = torch.load(args.transfer_resume)
                state_dict = checkpoint['state_dict']
                for k in list(state_dict.keys()):
                    # retain only encoder_q up to before the embedding layer
                    print(k)
                    if k.startswith("module.fc.") or k.startswith("fc."):
                        del state_dict[k]
                    elif k.startswith('module.'):
                        # remove prefix
                        state_dict[k[len("module."):]] = state_dict[k]
                        # delete renamed or unused k
                        del state_dict[k]
                msg = model.load_state_dict(state_dict, strict=False)
                assert set(msg.missing_keys) == {"fc.weight", "fc.bias"}
                print("=> loaded checkpoint '{}' (epoch {})".format(
                    args.transfer_resume, checkpoint['epoch']))
            else:
                print("=> no checkpoint found at '{}'".format(
                    args.transfer_resume))

        # init the fc layer
        model.fc.weight.data.normal_(mean=0.0, std=0.01)
        model.fc.bias.data.zero_()
    else:
        # Random Initialize
        print('Random Initialize')
        model = MODEL_DICT[args.model_name](num_classes=args.num_class,
                                            pretrained=False)

    print("Let's use", torch.cuda.device_count(), "GPUs!")
    model = nn.DataParallel(model)
    model = model.to(device)
    # Optimizer and learning rate scheduler

    optimizer = torch.optim.Adam(model.parameters(), lr=args.lr)
    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=10)
    if args.pretrained == 'Resume':
        # load weight from checkpoint
        print('Load weight from checkpoint {}'.format(args.resume))
        load_resume(args, model, optimizer, args.resume)

    args.start_epoch = 0
    metric = []
    # acc1, acc5, confusion_matrix, val_loss, aucs = test(model, test_loader, args.num_class, LOSS_FUNC, device)
    for epoch in range(args.start_epoch, args.epoch):
        # adjust_learning_rate(optimizer, epoch, args)
        train_loss = train(model, train_loader, optimizer, args.PRINT_INTERVAL,
                           epoch, args, LOSS_FUNC, device)
        acc1, acc5, confusion_matrix, val_loss, aucs = test(
            model, val_loader, args.num_class, LOSS_FUNC, device)
        metric.append(acc1)
        scheduler.step()
        # Save train/val loss, acc1, acc5, confusion matrix(F1, recall, precision), AUCs
        record = {
            'epoch': epoch + 1,
            'train loss': train_loss,
            'val loss': val_loss,
            'acc1': acc1,
            'acc5': acc5,
            'confusion matrix': confusion_matrix,
            'AUCs': aucs
        }
        torch.save(
            record,
            os.path.join(args.checkpoint_path,
                         'recordEpoch{}.pth.tar'.format(epoch)))
        # Only save the model with highest top1 acc
        if np.max(metric) == acc1:
            checkpoint = {
                'epoch': epoch + 1,
                'arch': args.model_name,
                'state_dict': model.state_dict(),
                'optimizer': optimizer.state_dict(),
            }
            torch.save(checkpoint,
                       os.path.join(args.checkpoint_path, 'best.pth.tar'))
            print("Model Saved")

    print('...........Testing..........')
    load_resume(args, model, optimizer,
                './checkpoint/caltech_finetune/best.pth.tar')
    acc1, acc5, confusion_matrix, val_loss, aucs = test(
        model, test_loader, args.num_class, LOSS_FUNC, device)
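load_resume is not defined in this excerpt; a sketch consistent with the checkpoints this script saves (keys: epoch, state_dict, optimizer):

def load_resume(args, model, optimizer, resume_path):
    if os.path.isfile(resume_path):
        checkpoint = torch.load(resume_path, map_location='cpu')
        args.start_epoch = checkpoint['epoch']
        model.load_state_dict(checkpoint['state_dict'])
        optimizer.load_state_dict(checkpoint['optimizer'])
        print("=> loaded checkpoint '{}' (epoch {})".format(
            resume_path, checkpoint['epoch']))
    else:
        print("=> no checkpoint found at '{}'".format(resume_path))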
Example #17
                               betas=(opt.b1, opt.b2))
optimizer_D = torch.optim.Adam(Net_D.parameters(),
                               lr=opt.lr,
                               betas=(opt.b1, opt.b2))

transforms_ = [
    torchvision.transforms.Resize(int(opt.img_height * 1.12), Image.BICUBIC),
    torchvision.transforms.RandomCrop((opt.img_height, opt.img_width)),
    torchvision.transforms.RandomHorizontalFlip(),
    torchvision.transforms.ToTensor(),
    torchvision.transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
]

# Training data loader
dataloader = DataLoader(ImageDataset(
    "D:/project/赵老师的IDEA/CycleGAN/input/apple2orange",
    transforms_=transforms_,
    unaligned=True),
                        batch_size=opt.batchsize,
                        shuffle=True)  #, num_workers=opt.n_cpu)

FloatTensor = torch.cuda.FloatTensor
patch = (1, opt.img_height // 2**4, opt.img_width // 2**4)
for epoch in range(opt.startepoch, 100):
    for i, batch in enumerate(dataloader):
        real_A = Variable(batch['B'].type(FloatTensor))
        real_B = Variable(batch['A'].type(FloatTensor))
        valid = Variable(FloatTensor(np.ones((opt.batchsize, *patch))),
                         requires_grad=False)
        fake = Variable(FloatTensor(np.zeros((opt.batchsize, *patch))),
                        requires_grad=False)
        optimizer_G.zero_grad()
Example #18
    print('Loading dataset...')
    SIZE = 224
    data_transforms = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize(mean=(0.485, 0.456, 0.406),
                             std=(0.229, 0.224, 0.225))
    ])

    def collate_fn(data):
        img_names, imgs = list(zip(*data))
        imgs = torch.stack(imgs)
        return img_names, imgs

    BATCH_SIZE = 4
    dataset = ImageDataset(dataset_path=Path(args.dataset),
                           size=SIZE,
                           istrain=True,
                           transforms=data_transforms)
    dataloader = torch.utils.data.DataLoader(dataset,
                                             batch_size=BATCH_SIZE,
                                             shuffle=False,
                                             num_workers=6,
                                             pin_memory=True,
                                             collate_fn=collate_fn)
    print('Done')

    # Acquire the database features.
    if model_name.startswith('efficientnet'):
        valid_layer = [2, 3]
        vec_length = np.array(embed.out_channels)[valid_layer].sum()
        img_name_np = []
        z_att_np = np.zeros([len(dataset), vec_length], dtype='float32')
Example #19
def evaluate():
    parser = argparse.ArgumentParser(description='Image Classification.')
    parser.add_argument('--model-name', type=str, default='resnet50')

    parser.add_argument('--batch-size',
                        type=int,
                        default=16,
                        help='Batch size')

    parser.add_argument('--root-dir',
                        type=str,
                        default='../data/Caltech256',
                        help='path to the image folder')
    parser.add_argument('--train-file',
                        type=str,
                        default='../dataset/caltech256_train_0.csv',
                        help='path to the train csv file')
    parser.add_argument('--val-file',
                        type=str,
                        default='../dataset/caltech256_val_0.csv',
                        help='path to the val csv file')
    parser.add_argument('--test-file',
                        type=str,
                        default='../dataset/caltech256_test_0.csv',
                        help='path to the test csv file')

    parser.add_argument('--resume',
                        type=str,
                        default='./checkpoint/SUN_best.pth.tar',
                        help='Path to resume a checkpoint')
    parser.add_argument('--num-class',
                        type=int,
                        default=257,
                        help='Number of class for the classification')

    args = parser.parse_args()

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print("Device {}".format(device))
    # Create checkpoint file

    normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                     std=[0.229, 0.224, 0.225])
    test_trans = transforms.Compose(
        [transforms.Resize((224, 224)),
         transforms.ToTensor(), normalize])

    Caltech = Caltech256(root=args.root_dir,
                         train_file=args.train_file,
                         val_file=args.val_file,
                         test_file=args.test_file,
                         download=True)

    testset = ImageDataset(images=Caltech.test_images,
                           labels=Caltech.test_label,
                           transforms=test_trans)

    test_loader = DataLoader(testset, batch_size=args.batch_size)

    print(args.model_name)
    # LOSS_FUNC = LabelSmoothSoftmaxCE()
    LOSS_FUNC = nn.CrossEntropyLoss()
    model = MODEL_DICT[args.model_name](num_classes=args.num_class)

    # if torch.cuda.device_count() > 1:
    print("Let's use", torch.cuda.device_count(), "GPUs!")
    model = nn.DataParallel(model).to(device)

    if args.resume:
        if os.path.isfile(args.resume):
            print("=> loading checkpoint '{}'".format(args.resume))

            checkpoint = torch.load(args.resume)
            model.load_state_dict(checkpoint['state_dict'])
            print("=> loaded checkpoint '{}' (epoch {})".format(
                args.resume, checkpoint['epoch']))
        else:
            print("=> no checkpoint found at '{}'".format(args.resume))

    print('...........Testing..........')
    acc1, acc5, confusion_matrix, val_loss, aucs = test(
        model, test_loader, args.num_class, LOSS_FUNC, device)
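
# `test(...)` above comes from elsewhere in the repository; for reference,
# top-1/top-5 accuracy can be computed per batch along these lines (a sketch
# under that assumption, not the repository's own implementation):
def topk_accuracy(output, target, ks=(1, 5)):
    """Fraction of samples whose true label appears in the top-k logits."""
    maxk = max(ks)
    _, pred = output.topk(maxk, dim=1)      # (N, maxk) highest-scoring classes
    correct = pred.eq(target.view(-1, 1))   # (N, maxk) elementwise hits
    return [correct[:, :k].any(dim=1).float().mean().item() for k in ks]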
Example No. 20
0
    def train(self):
        num_channels = self.config.NUM_CHANNELS
        use_cuda = self.config.USE_CUDA
        lr = self.config.LEARNING_RATE

        # Networks
        netG_A2B = Generator(num_channels)
        netG_B2A = Generator(num_channels)
        netD_A = Discriminator(num_channels)
        netD_B = Discriminator(num_channels)

        #netG_A2B = Generator_BN(num_channels)
        #netG_B2A = Generator_BN(num_channels)
        #netD_A = Discriminator_BN(num_channels)
        #netD_B = Discriminator_BN(num_channels)

        if use_cuda:
            netG_A2B.cuda()
            netG_B2A.cuda()
            netD_A.cuda()
            netD_B.cuda()

        netG_A2B.apply(weights_init_normal)
        netG_B2A.apply(weights_init_normal)
        netD_A.apply(weights_init_normal)
        netD_B.apply(weights_init_normal)

        criterion_GAN = torch.nn.BCELoss()
        criterion_cycle = torch.nn.L1Loss()
        criterion_identity = torch.nn.L1Loss()

        optimizer_G = torch.optim.Adam(itertools.chain(netG_A2B.parameters(), netG_B2A.parameters()),
                                       lr=lr, betas=(0.5, 0.999))
        optimizer_D_A = torch.optim.Adam(netD_A.parameters(), lr=lr, betas=(0.5, 0.999))
        optimizer_D_B = torch.optim.Adam(netD_B.parameters(), lr=lr, betas=(0.5, 0.999))

        lr_scheduler_G = torch.optim.lr_scheduler.LambdaLR(
            optimizer_G, lr_lambda=LambdaLR(self.config.EPOCH, 0, self.config.EPOCH // 2).step)
        lr_scheduler_D_A = torch.optim.lr_scheduler.LambdaLR(
            optimizer_D_A, lr_lambda=LambdaLR(self.config.EPOCH, 0, self.config.EPOCH // 2).step)
        lr_scheduler_D_B = torch.optim.lr_scheduler.LambdaLR(
            optimizer_D_B, lr_lambda=LambdaLR(self.config.EPOCH, 0, self.config.EPOCH // 2).step)
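        # `LambdaLR` above is the small decay helper common to CycleGAN
        # implementations, not torch.optim.lr_scheduler.LambdaLR; a minimal
        # sketch, assuming linear decay to zero after decay_start_epoch:
        #
        #   class LambdaLR:
        #       def __init__(self, n_epochs, offset, decay_start_epoch):
        #           self.n_epochs = n_epochs
        #           self.offset = offset
        #           self.decay_start_epoch = decay_start_epoch
        #       def step(self, epoch):
        #           return 1.0 - max(0, epoch + self.offset - self.decay_start_epoch) / (
        #               self.n_epochs - self.decay_start_epoch)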

        # Inputs & targets memory allocation
        FloatTensor = torch.cuda.FloatTensor if use_cuda else torch.FloatTensor
        batch_size = self.config.BATCH_SIZE
        height, width, channels = self.config.INPUT_SHAPE

        input_A = FloatTensor(batch_size, channels, height, width)
        input_B = FloatTensor(batch_size, channels, height, width)
        target_real = Variable(FloatTensor(batch_size).fill_(1.0), requires_grad=False)
        target_fake = Variable(FloatTensor(batch_size).fill_(0.0), requires_grad=False)

        fake_A_buffer = ReplayBuffer()
        fake_B_buffer = ReplayBuffer()
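        # ReplayBuffer keeps a pool of previously generated fakes, and
        # push_and_pop() returns a mix of fresh and historical samples to
        # stabilize discriminator training; a sketch, assuming the usual
        # 50-element pool (not necessarily this repository's exact code):
        #
        #   class ReplayBuffer:
        #       def __init__(self, max_size=50):
        #           self.max_size, self.data = max_size, []
        #       def push_and_pop(self, batch):
        #           out = []
        #           for el in batch.detach():
        #               el = el.unsqueeze(0)
        #               if len(self.data) < self.max_size:
        #                   self.data.append(el)
        #                   out.append(el)
        #               elif random.random() > 0.5:
        #                   i = random.randint(0, self.max_size - 1)
        #                   out.append(self.data[i].clone())
        #                   self.data[i] = el
        #               else:
        #                   out.append(el)
        #           return torch.cat(out)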

        transforms_ = [transforms.RandomCrop((height, width)),
                       transforms.RandomHorizontalFlip(),
                       transforms.ToTensor(),
                       transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))]

        dataloader = DataLoader(ImageDataset(self.config.DATA_DIR, self.config.DATASET_A, self.config.DATASET_B,
                                             transforms_=transforms_, unaligned=True),
                                             batch_size=batch_size, shuffle=True, num_workers=2, drop_last=True)
        # Loss plot
        logger = Logger(self.config.EPOCH, len(dataloader))

        now = datetime.datetime.now()
        datetime_sequence = "{0}{1:02d}{2:02d}_{3:02}{4:02d}".format(str(now.year)[-2:], now.month, now.day ,
                                                                    now.hour, now.minute)

        output_name_1 = self.config.DATASET_A + "2" + self.config.DATASET_B
        output_name_2 = self.config.DATASET_B + "2" + self.config.DATASET_A

        experiment_dir = os.path.join(self.config.RESULT_DIR, datetime_sequence)

        sample_output_dir_1 = os.path.join(experiment_dir, "sample", output_name_1)
        sample_output_dir_2 = os.path.join(experiment_dir, "sample", output_name_2)
        weights_output_dir_1 = os.path.join(experiment_dir, "weights", output_name_1)
        weights_output_dir_2 = os.path.join(experiment_dir, "weights", output_name_2)
        weights_output_dir_resume = os.path.join(experiment_dir, "weights", "resume")

        os.makedirs(sample_output_dir_1, exist_ok=True)
        os.makedirs(sample_output_dir_2, exist_ok=True)
        os.makedirs(weights_output_dir_1, exist_ok=True)
        os.makedirs(weights_output_dir_2, exist_ok=True)
        os.makedirs(weights_output_dir_resume, exist_ok=True)

        counter = 0

        for epoch in range(self.config.EPOCH):
            """
            logger.loss_df.to_csv(os.path.join(experiment_dir,
                                 self.config.DATASET_A + "_"
                                 + self.config.DATASET_B + ".csv"),
                    index=False)
            """
            if epoch % 100 == 0:
                torch.save(netG_A2B.state_dict(), os.path.join(weights_output_dir_1, str(epoch).zfill(4) + 'netG_A2B.pth'))
                torch.save(netG_B2A.state_dict(), os.path.join(weights_output_dir_2, str(epoch).zfill(4) + 'netG_B2A.pth'))
                torch.save(netD_A.state_dict(), os.path.join(weights_output_dir_1, str(epoch).zfill(4) + 'netD_A.pth'))
                torch.save(netD_B.state_dict(), os.path.join(weights_output_dir_2, str(epoch).zfill(4) + 'netD_B.pth'))

            for i, batch in enumerate(dataloader):
                # Set model input
                real_A = Variable(input_A.copy_(batch['A']))
                real_B = Variable(input_B.copy_(batch['B']))

                ###### Generators A2B and B2A ######
                optimizer_G.zero_grad()

                # GAN loss
                fake_B = netG_A2B(real_A)
                pred_fake_B = netD_B(fake_B)
                loss_GAN_A2B = criterion_GAN(pred_fake_B, target_real)

                fake_A = netG_B2A(real_B)
                pred_fake_A = netD_A(fake_A)
                loss_GAN_B2A = criterion_GAN(pred_fake_A, target_real)

                # Cycle loss
                recovered_A = netG_B2A(fake_B)
                loss_cycle_ABA = criterion_cycle(recovered_A, real_A) * 10.0

                recovered_B = netG_A2B(fake_A)
                loss_cycle_BAB = criterion_cycle(recovered_B, real_B) * 10.0
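
                # criterion_identity is instantiated above, but this variant
                # omits the identity term from the generator objective.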

                # Total loss
                loss_G = loss_GAN_A2B + loss_GAN_B2A + loss_cycle_ABA + loss_cycle_BAB
                loss_G.backward()

                optimizer_G.step()
                ###################################

                ###### Discriminator A ######
                optimizer_D_A.zero_grad()

                # Real loss
                pred_A = netD_A(real_A)
                loss_D_real = criterion_GAN(pred_A, target_real)

                # Fake loss
                fake_A_ = fake_A_buffer.push_and_pop(fake_A)
                pred_fake = netD_A(fake_A_.detach())
                loss_D_fake = criterion_GAN(pred_fake, target_fake)

                # Total loss
                loss_D_A = (loss_D_real + loss_D_fake) * 0.5
                loss_D_A.backward()

                optimizer_D_A.step()
                ###################################

                ###### Discriminator B ######
                optimizer_D_B.zero_grad()

                # Real loss
                pred_B = netD_B(real_B)
                loss_D_real = criterion_GAN(pred_B, target_real)

                # Fake loss
                fake_B_ = fake_B_buffer.push_and_pop(fake_B)
                pred_fake = netD_B(fake_B_.detach())
                loss_D_fake = criterion_GAN(pred_fake, target_fake)

                # Total loss
                loss_D_B = (loss_D_real + loss_D_fake) * 0.5
                loss_D_B.backward()

                optimizer_D_B.step()

                # Progress report (http://localhost:8097)
                logger.log({'loss_G': loss_G,
                            'loss_G_GAN': (loss_GAN_A2B + loss_GAN_B2A),
                            'loss_G_cycle': (loss_cycle_ABA + loss_cycle_BAB), 'loss_D': (loss_D_A + loss_D_B)},
                           images={'real_A': real_A, 'real_B': real_B, 'fake_A': fake_A, 'fake_B': fake_B})

                if counter % 500 == 0:
                    real_A_sample = real_A.cpu().detach().numpy()[0]
                    pred_A_sample = fake_A.cpu().detach().numpy()[0]
                    real_B_sample = real_B.cpu().detach().numpy()[0]
                    pred_B_sample = fake_B.cpu().detach().numpy()[0]
                    combine_sample_1 = np.concatenate([real_A_sample, pred_B_sample], axis=2)
                    combine_sample_2 = np.concatenate([real_B_sample, pred_A_sample], axis=2)

                    file_1 = "{0}_{1}.jpg".format(epoch, counter)
                    output_sample_image(os.path.join(sample_output_dir_1, file_1), combine_sample_1)
                    file_2 = "{0}_{1}.jpg".format(epoch, counter)
                    output_sample_image(os.path.join(sample_output_dir_2, file_2), combine_sample_2)

                counter += 1


            # Update learning rates
            lr_scheduler_G.step()
            lr_scheduler_D_A.step()
            lr_scheduler_D_B.step()

        torch.save(netG_A2B.state_dict(), os.path.join(weights_output_dir_1, str(self.config.EPOCH).zfill(4) + 'netG_A2B.pth'))
        torch.save(netG_B2A.state_dict(), os.path.join(weights_output_dir_2, str(self.config.EPOCH).zfill(4) + 'netG_B2A.pth'))
        torch.save(netD_A.state_dict(), os.path.join(weights_output_dir_1, str(self.config.EPOCH).zfill(4) + 'netD_A.pth'))
        torch.save(netD_B.state_dict(), os.path.join(weights_output_dir_2, str(self.config.EPOCH).zfill(4) + 'netD_B.pth'))
Example No. 21
0
# Losses
criterion_GAN = GANLoss(tensor=Tensor)
criterion_cycle = torch.nn.L1Loss()
criterion_identity = torch.nn.L1Loss()
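
# GANLoss builds the real/fake target tensors internally, in the style of the
# pix2pix/CycleGAN reference code; a sketch of the idea, assuming a
# least-squares (MSE) objective (not necessarily this repository's class):
#
#   class GANLoss(nn.Module):
#       def __init__(self, tensor=torch.FloatTensor):
#           super().__init__()
#           self.tensor = tensor
#           self.loss = nn.MSELoss()
#       def forward(self, prediction, target_is_real):
#           fill = 1.0 if target_is_real else 0.0
#           target = self.tensor(prediction.size()).fill_(fill)
#           return self.loss(prediction, target)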

# Dataset loader
transform = [
    transforms.Resize(int(opt.size * 1.12), Image.BICUBIC),
    transforms.RandomCrop(opt.size),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
]
dataloader = DataLoader(ImageDataset(opt.dataroot,
                                     transform=transform,
                                     unaligned=True),
                        batch_size=opt.batchSize,
                        shuffle=True,
                        num_workers=opt.n_cpu)

# Loss plot
logger = Logger(opt.n_epochs, len(dataloader))
###################################

######### Training ############
for epoch in range(opt.epoch, opt.n_epochs):
    for i, batch in enumerate(dataloader):
        # Model input
        real_A = Variable(input_A.copy_(batch['A']))
        real_B = Variable(input_B.copy_(batch['B']))
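        # input_A and input_B are pre-allocated buffers filled in place each
        # iteration, following the same pattern as Example No. 20; their
        # definitions fall outside this excerpt.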