コード例 #1
0
def generate_two_moons_data(n_noisy_dimensions):
    """Build train/valid/test DataLoaders for a noisy two-moons problem.

    The classic 2-D two-moons point cloud is augmented with
    ``n_noisy_dimensions`` extra columns of pure N(0, 1) noise, then split
    20/40/40 into train/validation/test sets.

    Args:
        n_noisy_dimensions: number of uninformative Gaussian-noise features
            appended to the two informative coordinates.

    Returns:
        ``(train_loader, valid_loader, test_loader)``; the validation and
        test loaders each yield their entire split as a single batch.
    """
    n_samples = 1000
    batch_size = 32

    # Informative part: the standard interleaving half-circles.
    X, y = datasets.make_moons(n_samples=n_samples, noise=0.1)
    X = X.astype(np.float32)

    # Uninformative part: i.i.d. standard-normal noise columns.
    noise = np.random.normal(loc=0, scale=1.0,
                             size=(n_samples, n_noisy_dimensions))

    X_all = np.concatenate((X, noise), axis=1).astype(np.float32)

    # First split keeps 20% for training; the remaining 80% is then halved
    # into validation and test sets.
    X_train, X_rest, y_train, y_rest = train_test_split(
        X_all, y, test_size=0.8
    )
    X_valid, X_test, y_valid, y_test = train_test_split(
        X_rest, y_rest, test_size=0.5
    )

    loaders = []
    for feats, labels, bs in ((X_train, y_train, batch_size),
                              (X_valid, y_valid, y_valid.shape[0]),
                              (X_test, y_test, y_test.shape[0])):
        loaders.append(DataLoader(BasicDataset(feats, labels), batch_size=bs))

    return tuple(loaders)
コード例 #2
0
def generate_proposals(params, prefix, oprefix, name, dim, no_normalize=False):
    """Generate cluster proposals for one dataset under several parameter sets.

    For each entry in ``params`` a first-iteration ("iter0") proposal folder
    is produced via ``generate_basic_proposals``; optional second-iteration
    refinements listed under the ``iter1_params`` key are then generated on
    top of the iter0 predictions.

    Returns:
        List of proposal folders.  An iter0 folder is included only when the
        parameter set's ``iter0`` flag is true (the default); every iter1
        folder is always included.
    """
    dataset = BasicDataset(name=name,
                           prefix=prefix,
                           dim=dim,
                           normalize=not no_normalize)
    dataset.info()

    # These prefixes do not depend on the loop variable, so compute them once.
    base_oprefix = osp.join(oprefix, name)
    base_knn_prefix = osp.join(prefix, 'knns', name)

    out_folders = []
    for param in params:
        folder_i0, pred_labels_i0 = generate_basic_proposals(
            oprefix=base_oprefix,
            knn_prefix=base_knn_prefix,
            feats=dataset.features,
            feat_dim=dim,
            **param)

        # Keep the first-iteration folder unless explicitly disabled.
        if param.get('iter0', True):
            out_folders.append(folder_i0)

        # Second-iteration proposals are rooted next to the iter0 output and
        # are supervised by the iter0 predicted labels.
        for param_i1 in param.get('iter1_params', []):
            oprefix_i1 = osp.dirname(folder_i0)
            folder_i1, _ = generate_iter_proposals(
                oprefix=oprefix_i1,
                knn_prefix=osp.join(oprefix_i1, 'knns'),
                feats=dataset.features,
                feat_dim=dim,
                sv_labels=pred_labels_i0,
                sv_knn_prefix=base_knn_prefix,
                **param_i1)
            out_folders.append(folder_i1)

    return out_folders
コード例 #3
0
File: test.py — Project: takumi5757/cycleGAN-face
def main():
    """Run CycleGAN inference on the test split.

    Loads generator/discriminator checkpoints from the final training epoch,
    translates every test image in both directions (A->B and B->A), and
    writes real/translated/reconstructed image grids to ``result/<model>/``.
    """
    parser = argparse.ArgumentParser(description="PyTorch implementation: CycleGAN")
    # for train
    parser.add_argument(
        "--image_size", "-i", type=int, default=256, help="input image size"
    )
    parser.add_argument(
        "--batch_size",
        "-b",
        type=int,
        default=1,
        help="Number of images in each mini-batch",
    )
    parser.add_argument("--epoch", "-e", type=int, default=200, help="Number of epochs")
    parser.add_argument(
        "--epoch_decay",
        "-ed",
        type=int,
        default=100,
        help="Number of epochs to start decaying learning rate to zero",
    )
    parser.add_argument(
        "--beta1", type=float, default=0.5, help="momentum term of adam"
    )
    parser.add_argument("--lr", type=float, default=0.0002, help="learning rate")
    parser.add_argument(
        "--pool_size",
        type=int,
        default=50,
        help="for discriminator: the size of image buffer that stores previously generated images",
    )
    parser.add_argument(
        "--lambda_cycle",
        type=float,
        default=10.0,
        help="Assumptive weight of cycle consistency loss",
    )
    parser.add_argument(
        "--gpu",
        "-g",
        type=int,
        default=-1,
        help="GPU ID (negative value indicates CPU)",
    )
    # for save and load
    parser.add_argument(
        "--sample_frequecy",
        "-sf",
        type=int,
        default=1000,
        help="Frequency of taking a sample",
    )
    parser.add_argument(
        "--checkpoint_frequecy",
        "-cf",
        type=int,
        default=1,
        help="Frequency of taking a checkpoint",
    )
    parser.add_argument("--data_name", "-d", default="horse2zebra", help="Dataset name")
    parser.add_argument(
        "--out", "-o", default="result/", help="Directory to output the result"
    )
    parser.add_argument(
        "--log_dir", "-l", default="logs/", help="Directory to output the log"
    )
    parser.add_argument("--model", "-m", help="Model name")
    args = parser.parse_args()

    # set GPU or CPU
    # BUG FIX: honor the --gpu flag ("negative value indicates CPU") instead
    # of unconditionally using CUDA whenever it is available.
    device = torch.device(
        "cuda:0" if args.gpu >= 0 and torch.cuda.is_available() else "cpu"
    )

    # set depth of resnet: 6 residual blocks for 128px inputs, 9 otherwise
    if args.image_size == 128:
        res_block = 6
    else:
        res_block = 9

    # set models (discriminators are loaded for completeness but unused below)
    G_A2B = Generator(3, res_block).to(device)
    G_B2A = Generator(3, res_block).to(device)
    D_A = Discriminator(3).to(device)
    D_B = Discriminator(3).to(device)

    # data pararell
    # if device == 'cuda':
    #     G_A2B = torch.nn.DataParallel(G_A2B)
    #     G_B2A = torch.nn.DataParallel(G_B2A)
    #     D_A = torch.nn.DataParallel(D_A)
    #     D_B = torch.nn.DataParallel(D_B)
    #     torch.backends.cudnn.benchmark=True

    # load parameters from the last training epoch's checkpoints
    G_A2B.load_state_dict(
        torch.load("models/" + args.model + "/G_A2B/" + str(args.epoch - 1) + ".pth")
    )
    G_B2A.load_state_dict(
        torch.load("models/" + args.model + "/G_B2A/" + str(args.epoch - 1) + ".pth")
    )
    D_A.load_state_dict(
        torch.load("models/" + args.model + "/D_A/" + str(args.epoch - 1) + ".pth")
    )
    D_B.load_state_dict(
        torch.load("models/" + args.model + "/D_B/" + str(args.epoch - 1) + ".pth")
    )

    test_dataset = BasicDataset(args.data_name, args.image_size, is_train=False)
    test_loader = torch.utils.data.DataLoader(
        test_dataset, batch_size=args.batch_size, shuffle=True, num_workers=0
    )

    with torch.no_grad():
        if not os.path.exists("result/" + args.model):
            os.makedirs("result/" + args.model)
        for i, data in enumerate(test_loader):
            real_A = data["A"].to(device)
            real_B = data["B"].to(device)
            # Translate in both directions, then reconstruct via the inverse map.
            trans_B, trans_A = G_B2A(real_B), G_A2B(real_A)
            rec_A, rec_B = G_B2A(trans_A), G_A2B(trans_B)

            # Each grid row: real, translated, reconstructed.
            image_A = torch.cat((real_A, trans_A, rec_A), 0)
            image_B = torch.cat((real_B, trans_B, rec_B), 0)
            save_image(
                image_A,
                "result/" + args.model + "/A_" + str(i) + ".png",
                nrow=3,
                normalize=True,
            )
            save_image(
                image_B,
                "result/" + args.model + "/B_" + str(i) + ".png",
                nrow=3,
                normalize=True,
            )

            sys.stdout.write(f"\r[Number {i+1}/{len(test_loader)}]")
コード例 #4
0
    # output cluster proposals
    ofolder_proposals = osp.join(ofolder, 'proposals')
    if is_save_proposals:
        print('saving cluster proposals to {}'.format(ofolder_proposals))
        if not osp.exists(ofolder_proposals):
            os.makedirs(ofolder_proposals)
        save_proposals(clusters, knns, ofolder=ofolder_proposals, force=force)

    return ofolder_proposals, ofn_pred_labels


if __name__ == '__main__':
    args = parse_args()

    ds = BasicDataset(name=args.name,
                      prefix=args.prefix,
                      dim=args.dim,
                      normalize=not args.no_normalize)
    ds.info()

    generate_basic_proposals(osp.join(args.oprefix, args.name),
                             osp.join(args.prefix, 'knns', args.name),
                             ds.features,
                             args.dim,
                             args.knn_method,
                             args.k,
                             args.th_knn,
                             args.th_step,
                             args.minsz,
                             args.maxsz,
                             is_rebuild=args.is_rebuild,
                             is_save_proposals=args.is_save_proposals,
コード例 #5
0

def load_GPUS(model, model_path, kwargs):
    """Load a DataParallel checkpoint into a plain (single-device) model.

    Checkpoints saved from a ``torch.nn.DataParallel``-wrapped network prefix
    every parameter key with ``module.``; this strips that prefix from the
    ``'net'`` state dict stored in the checkpoint before loading.

    Args:
        model: target module that receives the weights.
        model_path: path of the checkpoint file.
        kwargs: dict of extra keyword arguments forwarded to ``torch.load``
            (e.g. ``{'map_location': 'cpu'}``).

    Returns:
        The same ``model`` instance, with weights loaded.
    """
    from collections import OrderedDict

    checkpoint = torch.load(model_path, **kwargs)
    # Drop the leading 'module.' (7 characters) added by DataParallel.
    stripped = OrderedDict(
        (key[7:], value) for key, value in checkpoint['net'].items()
    )
    model.load_state_dict(stripped)
    return model


if __name__ == "__main__":
    dir_checkpoint = 'checkpoints/'
    # Run on GPU when available, otherwise fall back to CPU.
    device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
    net = Unet(n_channels=3, n_classes=8, bilinear=True)
    net.to(device=device)
    # BUG FIX: map_location keeps GPU-saved checkpoints loadable on
    # CPU-only hosts; without it torch.load would raise when CUDA is absent.
    model = torch.load(dir_checkpoint + 'best_score_model_unet.pth',
                       map_location=device)
    net.load_state_dict(model['net'])
    # Alternative for DataParallel-saved checkpoints:
    #net = load_GPUS(net, dir_checkpoint + 'student_net.pth', kwargs)
    sate_dataset_val = BasicDataset("./data/val.lst")
    # drop_last discards a final incomplete batch of the validation set.
    eval_dataloader = DataLoader(sate_dataset_val,
                                 batch_size=32,
                                 shuffle=True,
                                 num_workers=5,
                                 drop_last=True)
    print("begin")
    eval_net(net, eval_dataloader, device)
コード例 #6
0
def test_gcn_v(model, cfg, logger):
    """Run GCN-V inference and clustering.

    Predicts per-instance confidences (or loads cached ones), converts them
    to peaks and then to cluster labels, evaluates against ground truth, and
    optionally rebuilds the kNN graph on the learned GCN features for a
    refined clustering.

    BUG FIX: removed a leftover ``pdb.set_trace()`` (and its import) that
    halted every run before the final image-export step.
    """
    # Propagate model kwargs (e.g. feature dims) onto the test-data config.
    for k, v in cfg.model['kwargs'].items():
        setattr(cfg.test_data, k, v)
    dataset = build_dataset(cfg.model['type'], cfg.test_data)

    folder = '{}_gcnv_k_{}_th_{}'.format(cfg.test_name, cfg.knn, cfg.th_sim)
    oprefix = osp.join(cfg.work_dir, folder)
    oname = osp.basename(rm_suffix(cfg.load_from))
    opath_pred_confs = osp.join(oprefix, 'pred_confs', '{}.npz'.format(oname))

    # Reuse cached confidences unless forced; otherwise run the model.
    # NOTE(review): on the cached branch `gcn_feat` is never assigned, so
    # cfg.use_gcn_feat would hit a NameError at write_feat below — confirm.
    if osp.isfile(opath_pred_confs) and not cfg.force:
        data = np.load(opath_pred_confs)
        pred_confs = data['pred_confs']
        inst_num = data['inst_num']
        if inst_num != dataset.inst_num:
            logger.warn(
                'instance number in {} is different from dataset: {} vs {}'.
                format(opath_pred_confs, inst_num, len(dataset)))
    else:
        pred_confs, gcn_feat = test(model, dataset, cfg, logger)
        inst_num = dataset.inst_num

    logger.info('pred_confs: mean({:.4f}). max({:.4f}), min({:.4f})'.format(
        pred_confs.mean(), pred_confs.max(), pred_confs.min()))

    logger.info('Convert to cluster')
    with Timer('Predition to peaks'):
        # For each instance, find its nearest higher-confidence neighbors.
        pred_dist2peak, pred_peaks = confidence_to_peaks(
            dataset.dists, dataset.nbrs, pred_confs, cfg.max_conn)

    if not dataset.ignore_label and cfg.eval_interim:
        # evaluate the intermediate results: accuracy of predicted peak
        # identity and of peak-label agreement, per connection rank.
        for i in range(cfg.max_conn):
            num = len(dataset.peaks)
            pred_peaks_i = np.arange(num)
            peaks_i = np.arange(num)
            for j in range(num):
                if len(pred_peaks[j]) > i:
                    pred_peaks_i[j] = pred_peaks[j][i]
                if len(dataset.peaks[j]) > i:
                    peaks_i[j] = dataset.peaks[j][i]
            acc = accuracy(pred_peaks_i, peaks_i)
            logger.info('[{}-th conn] accuracy of peak match: {:.4f}'.format(
                i + 1, acc))
            acc = 0.
            for idx, peak in enumerate(pred_peaks_i):
                acc += int(dataset.idx2lb[peak] == dataset.idx2lb[idx])
            acc /= len(pred_peaks_i)
            logger.info(
                '[{}-th conn] accuracy of peak label match: {:.4f}'.format(
                    i + 1, acc))

    with Timer('Peaks to clusters (th_cut={})'.format(cfg.tau_0)):
        pred_labels = peaks_to_labels(pred_peaks, pred_dist2peak, cfg.tau_0,
                                      inst_num)

    if cfg.save_output:
        logger.info('save predicted confs to {}'.format(opath_pred_confs))
        mkdir_if_no_exists(opath_pred_confs)
        np.savez_compressed(opath_pred_confs,
                            pred_confs=pred_confs,
                            inst_num=inst_num)

        # save clustering results
        idx2lb = list2dict(pred_labels, ignore_value=-1)

        opath_pred_labels = osp.join(
            cfg.work_dir, folder, 'tau_{}_pred_labels.txt'.format(cfg.tau_0))
        logger.info('save predicted labels to {}'.format(opath_pred_labels))
        mkdir_if_no_exists(opath_pred_labels)
        write_meta(opath_pred_labels, idx2lb, inst_num=inst_num)

    # evaluation against ground-truth labels, if available
    if not dataset.ignore_label:
        print('==> evaluation')
        for metric in cfg.metrics:
            evaluate(dataset.gt_labels, pred_labels, metric)

    if cfg.use_gcn_feat:
        # gcn_feat is saved to disk for GCN-E
        opath_feat = osp.join(oprefix, 'features', '{}.bin'.format(oname))
        if not osp.isfile(opath_feat) or cfg.force:
            mkdir_if_no_exists(opath_feat)
            write_feat(opath_feat, gcn_feat)

        name = rm_suffix(osp.basename(opath_feat))
        prefix = oprefix
        ds = BasicDataset(name=name,
                          prefix=prefix,
                          dim=cfg.model['kwargs']['nhid'],
                          normalize=True)
        ds.info()

        # use top embedding of GCN to rebuild the kNN graph
        with Timer('connect to higher confidence with use_gcn_feat'):
            knn_prefix = osp.join(prefix, 'knns', name)
            knns = build_knns(knn_prefix,
                              ds.features,
                              cfg.knn_method,
                              cfg.knn,
                              is_rebuild=True)
            dists, nbrs = knns2ordered_nbrs(knns)

            pred_dist2peak, pred_peaks = confidence_to_peaks(
                dists, nbrs, pred_confs, cfg.max_conn)
            pred_labels = peaks_to_labels(pred_peaks, pred_dist2peak, cfg.tau,
                                          inst_num)

        # save clustering results
        if cfg.save_output:
            oname_meta = '{}_gcn_feat'.format(name)
            opath_pred_labels = osp.join(
                oprefix, oname_meta, 'tau_{}_pred_labels.txt'.format(cfg.tau))
            mkdir_if_no_exists(opath_pred_labels)

            idx2lb = list2dict(pred_labels, ignore_value=-1)
            write_meta(opath_pred_labels, idx2lb, inst_num=inst_num)

        # evaluation of the refined clustering
        if not dataset.ignore_label:
            print('==> evaluation')
            for metric in cfg.metrics:
                evaluate(dataset.gt_labels, pred_labels, metric)

        # Ad-hoc export: copy each test image into a per-cluster folder.
        # NOTE(review): hard-coded local paths, and "cluter" is a typo kept
        # for compatibility with existing output directories.
        import json
        import os
        img_labels = json.load(
            open(r'/home/finn/research/data/clustering_data/test_index.json',
                 'r',
                 encoding='utf-8'))
        import shutil
        output = r'/home/finn/research/data/clustering_data/mr_gcn_output'
        for label in set(pred_labels):
            if not os.path.exists(os.path.join(output, f'cluter_{label}')):
                os.mkdir(os.path.join(output, f'cluter_{label}'))
        for image in img_labels:
            shutil.copy2(
                image,
                os.path.join(
                    os.path.join(output,
                                 f'cluter_{pred_labels[img_labels[image]]}'),
                    os.path.split(image)[-1]))
コード例 #7
0
File: test.py — Project: DH-rgb/cycle-gan
def main():
    """Run CycleGAN inference: translate the test split in both directions
    and save real/translated/reconstructed grids under ``result/<model>/``."""
    parser = argparse.ArgumentParser(description='PyTorch implementation: CycleGAN')
    # training-related options (kept so checkpoint naming matches training)
    parser.add_argument('--image_size', '-i', type=int, default=256,
                        help='input image size')
    parser.add_argument('--batch_size', '-b', type=int, default=1,
                        help='Number of images in each mini-batch')
    parser.add_argument('--epoch', '-e', type=int, default=200,
                        help='Number of epochs')
    parser.add_argument('--epoch_decay', '-ed', type=int, default=100,
                        help='Number of epochs to start decaying learning rate to zero')
    parser.add_argument('--beta1', type=float, default=0.5,
                        help='momentum term of adam')
    parser.add_argument('--lr', type=float, default=0.0002,
                        help='learning rate')
    parser.add_argument('--pool_size', type=int, default=50,
                        help='for discriminator: the size of image buffer that stores previously generated images')
    parser.add_argument('--lambda_cycle', type=float, default=10.0,
                        help='Assumptive weight of cycle consistency loss')
    parser.add_argument('--gpu', '-g', type=int, default=-1,
                        help='GPU ID (negative value indicates CPU)')
    # save/load options
    parser.add_argument('--sample_frequecy', '-sf', type=int, default=1000,
                        help='Frequency of taking a sample')
    parser.add_argument('--checkpoint_frequecy', '-cf', type=int, default=1,
                        help='Frequency of taking a checkpoint')
    parser.add_argument('--data_name', '-d', default="horse2zebra",
                        help='Dataset name')
    parser.add_argument('--out', '-o', default='result/',
                        help='Directory to output the result')
    parser.add_argument('--log_dir', '-l', default='logs/',
                        help='Directory to output the log')
    parser.add_argument('--model', '-m', help='Model name')
    args = parser.parse_args()

    # Device: CPU unless a non-negative GPU id was requested and CUDA works.
    device = 'cuda' if args.gpu >= 0 and torch.cuda.is_available() else 'cpu'

    # ResNet depth: 6 residual blocks for 128px inputs, 9 otherwise.
    res_block = 6 if args.image_size == 128 else 9

    # Models (discriminators are loaded for completeness but unused below).
    G_A2B = Generator(3, res_block).to(device)
    G_B2A = Generator(3, res_block).to(device)
    D_A = Discriminator(3).to(device)
    D_B = Discriminator(3).to(device)

    # data pararell
    # if device == 'cuda':
    #     G_A2B = torch.nn.DataParallel(G_A2B)
    #     G_B2A = torch.nn.DataParallel(G_B2A)
    #     D_A = torch.nn.DataParallel(D_A)
    #     D_B = torch.nn.DataParallel(D_B)
    #     torch.backends.cudnn.benchmark=True

    # Restore the weights saved at the final training epoch.
    def _ckpt_path(part):
        return "models/" + args.model + "/" + part + "/" + str(args.epoch - 1) + ".pth"

    G_A2B.load_state_dict(torch.load(_ckpt_path("G_A2B")))
    G_B2A.load_state_dict(torch.load(_ckpt_path("G_B2A")))
    D_A.load_state_dict(torch.load(_ckpt_path("D_A")))
    D_B.load_state_dict(torch.load(_ckpt_path("D_B")))

    test_dataset = BasicDataset(args.data_name, args.image_size, is_train=False)
    test_loader = torch.utils.data.DataLoader(test_dataset,
                                              batch_size=args.batch_size,
                                              shuffle=True,
                                              num_workers=0)

    with torch.no_grad():
        out_dir = "result/" + args.model
        if not os.path.exists(out_dir):
            os.makedirs(out_dir)
        for idx, batch in enumerate(test_loader):
            real_A = batch['A'].to(device)
            real_B = batch['B'].to(device)
            # Translate both directions, then reconstruct via the inverse map.
            trans_B, trans_A = G_B2A(real_B), G_A2B(real_A)
            rec_A, rec_B = G_B2A(trans_A), G_A2B(trans_B)

            # Grid layout per image: real, translated, reconstructed.
            grid_A = torch.cat((real_A, trans_A, rec_A), 0)
            grid_B = torch.cat((real_B, trans_B, rec_B), 0)
            save_image(grid_A, out_dir + "/A_" + str(idx) + ".png",
                       nrow=3, normalize=True)
            save_image(grid_B, out_dir + "/B_" + str(idx) + ".png",
                       nrow=3, normalize=True)

            sys.stdout.write(f"\r[Number {idx+1}/{len(test_loader)}]")
コード例 #8
0
def create_h5_all_processed(source_h5,
                            target,
                            MAX_NUM_BOXES=cfg['max_box_num'],
                            WRITE_CHUNK=cfg['dataloader_cfg']['batch_size']):
    """Preprocess every sample of ``source_h5`` through the tokenizing
    DataLoader and stream the results into a new HDF5 file at ``target``.

    Datasets are created with an initial size of one write-chunk and grown
    with ``resize`` as batches arrive; the raw ``others/data`` table is
    copied over from the source file unchanged.

    NOTE(review): ``MAX_NUM_BOXES`` is accepted but never used in the body —
    box counts come from ``cfg['max_box_num']`` directly; confirm intent.
    """
    ds = BasicDataset(source_h5, tokenizer, cfg)
    # shuffle=False keeps output row order aligned with the source file.
    dl = data.DataLoader(ds,
                         shuffle=False,
                         collate_fn=BasicDataset.Collate_fn,
                         **cfg['dataloader_cfg'])

    with h5py.File(target, 'w', libver='latest') as hf:

        hf.create_group('querys')
        hf.create_group('box_poss')
        hf.create_group('box_features')
        hf.create_group('box_labels')

        # Resizable datasets: maxshape=(None, ...) allows growth along axis 0.
        querys_h5ds = hf.create_dataset(
            'querys/data',
            shape=(WRITE_CHUNK, cfg['max_query_word']),
            chunks=(1, cfg['max_query_word']),
            maxshape=(None, cfg['max_query_word']),
            #compression="lzf",
            dtype='i')
        box_poss_h5ds = hf.create_dataset(
            'box_poss/data',
            shape=(WRITE_CHUNK, cfg['max_box_num'], 5),
            chunks=(1, cfg['max_box_num'], 5),
            maxshape=(None, cfg['max_box_num'], 5),
            #compression="lzf",
            dtype='f')
        # NOTE(review): group is 'box_features' but the dataset path is
        # 'box_feature/data' (singular) — readers must use the latter.
        box_features_h5ds = hf.create_dataset(
            'box_feature/data',
            shape=(WRITE_CHUNK, cfg['max_box_num'], 2048),
            chunks=(1, cfg['max_box_num'], 2048),
            maxshape=(None, cfg['max_box_num'], 2048),
            #compression="lzf",
            dtype='f')
        box_labels_h5ds = hf.create_dataset(
            'box_label/data',
            shape=(WRITE_CHUNK, cfg['max_box_num'], cfg['max_class_word_num']),
            chunks=(1, cfg['max_box_num'], cfg['max_class_word_num']),
            maxshape=(None, cfg['max_box_num'], cfg['max_class_word_num']),
            #compression="lzf",
            dtype='f')
        others_h5ds = hf.create_dataset(
            'others/data',
            shape=(WRITE_CHUNK, 5),
            chunks=(1, 5),
            maxshape=(None, 5),
            #compression="lzf",
            dtype='i')

        def flush_into_ds(hf, i, query, box_pos, box_feature, box_label):
            # Grow each dataset to hold `i` rows, then write the latest batch
            # into the tail slice starting at the previous chunk boundary.
            querys_h5ds = hf.get('querys/data')
            querys_h5ds.resize(i, axis=0)
            querys_h5ds[(i - 1) // WRITE_CHUNK * WRITE_CHUNK:i, :] = query

            box_poss_h5ds = hf.get('box_poss/data')
            box_poss_h5ds.resize(i, axis=0)
            box_poss_h5ds[(i - 1) // WRITE_CHUNK *
                          WRITE_CHUNK:i, :, :] = box_pos

            box_features_h5ds = hf.get('box_feature/data')
            box_features_h5ds.resize(i, axis=0)
            box_features_h5ds[(i - 1) // WRITE_CHUNK *
                              WRITE_CHUNK:i, :] = box_feature

            # NOTE(review): this slice ends at i+1 while the others end at i —
            # looks inconsistent; confirm whether the +1 is intentional.
            box_labels_h5ds = hf.get('box_label/data')
            box_labels_h5ds.resize(i, axis=0)
            box_labels_h5ds[(i - 1) // WRITE_CHUNK * WRITE_CHUNK:i +
                            1, :, :] = box_label

            return 0

        # Stream the DataLoader output; `i` tracks total rows written so far.
        i = 0
        for query, box_pos, box_feature, box_label, _ in tqdm(dl):
            query, box_pos, box_feature, box_label = query.numpy(
            ), box_pos.numpy(), box_feature.numpy(), box_label.numpy()
            i += query.shape[0]
            flush_into_ds(hf, i, query, box_pos, box_feature, box_label)
        print('reading others\r', end='')
        # Copy the auxiliary 'others' table row-by-row from the source file.
        with h5py.File(source_h5, 'r', libver='latest') as h5file_source:
            others_h5ds_source = h5file_source.get('others/data')
            len_others = h5file_source.get('others/data').shape[0]
            others_h5ds.resize(len_others, axis=0)
            for i in range(len_others):
                others_h5ds[i] = others_h5ds_source[i]
        print('reading others finished!')

    return
コード例 #9
0
File: train.py — Project: DH-rgb/cycle-gan
def main():
    """Train a CycleGAN: two generators (A<->B, B<->A) and two
    discriminators, optimized with adversarial (LSGAN/MSE), cycle-consistency
    and optional identity losses.  Samples, TensorBoard logs and periodic
    checkpoints are written under sample/, logs/ and models/ respectively.
    """
    parser = argparse.ArgumentParser(
        description='PyTorch implementation: CycleGAN')
    #for train
    parser.add_argument('--image_size',
                        '-i',
                        type=int,
                        default=256,
                        help='input image size')
    parser.add_argument('--batch_size',
                        '-b',
                        type=int,
                        default=1,
                        help='Number of images in each mini-batch')
    parser.add_argument('--epoch',
                        '-e',
                        type=int,
                        default=200,
                        help='Number of epochs')
    parser.add_argument(
        '--epoch_decay',
        '-ed',
        type=int,
        default=100,
        help='Number of epochs to start decaying learning rate to zero')
    parser.add_argument('--beta1',
                        type=float,
                        default=0.5,
                        help='momentum term of adam')
    parser.add_argument('--lr',
                        type=float,
                        default=0.0002,
                        help='learning rate')
    parser.add_argument(
        '--pool_size',
        type=int,
        default=50,
        help=
        'for discriminator: the size of image buffer that stores previously generated images'
    )
    parser.add_argument('--lambda_cycle',
                        type=float,
                        default=10.0,
                        help='Assumptive weight of cycle consistency loss')
    parser.add_argument('--lambda_identity',
                        type=float,
                        default=0,
                        help='Assumptive weight of identity mapping loss')
    parser.add_argument('--gpu',
                        '-g',
                        type=int,
                        default=-1,
                        help='GPU ID (negative value indicates CPU)')
    #for save and load
    parser.add_argument('--sample_frequecy',
                        '-sf',
                        type=int,
                        default=5000,
                        help='Frequency of taking a sample')
    parser.add_argument('--checkpoint_frequecy',
                        '-cf',
                        type=int,
                        default=10,
                        help='Frequency of taking a checkpoint')
    parser.add_argument('--data_name',
                        '-d',
                        default="horse2zebra",
                        help='Dataset name')
    parser.add_argument('--out',
                        '-o',
                        default='result/',
                        help='Directory to output the result')
    parser.add_argument('--log_dir',
                        '-l',
                        default='logs/',
                        help='Directory to output the log')
    parser.add_argument('--model', '-m', help='Model name')
    args = parser.parse_args()

    #set GPU or CPU (CPU unless a non-negative GPU id was given and CUDA works)
    if args.gpu >= 0 and torch.cuda.is_available():
        device = 'cuda'
    else:
        device = 'cpu'

    #set depth of resnet (6 residual blocks for 128px inputs, 9 otherwise)
    if args.image_size == 128:
        res_block = 6
    else:
        res_block = 9

    #set models
    G_A2B = Generator(3, res_block).to(device)
    G_B2A = Generator(3, res_block).to(device)
    D_A = Discriminator(3).to(device)
    D_B = Discriminator(3).to(device)

    # data pararell
    # if device == 'cuda':
    #     G_A2B = torch.nn.DataParallel(G_A2B)
    #     G_B2A = torch.nn.DataParallel(G_B2A)
    #     D_A = torch.nn.DataParallel(D_A)
    #     D_B = torch.nn.DataParallel(D_B)
    #     torch.backends.cudnn.benchmark=True

    #init weights
    G_A2B.apply(init_weights)
    G_B2A.apply(init_weights)
    D_A.apply(init_weights)
    D_B.apply(init_weights)

    #set loss functions (MSE adversarial loss = least-squares GAN objective)
    adv_loss = nn.MSELoss()
    cycle_loss = nn.L1Loss()
    identity_loss = nn.L1Loss()

    #set optimizers (both generators share one optimizer)
    optimizer_G = torch.optim.Adam(chain(G_A2B.parameters(),
                                         G_B2A.parameters()),
                                   lr=args.lr,
                                   betas=(args.beta1, 0.999))
    optimizer_D_A = torch.optim.Adam(D_A.parameters(),
                                     lr=args.lr,
                                     betas=(args.beta1, 0.999))
    optimizer_D_B = torch.optim.Adam(D_B.parameters(),
                                     lr=args.lr,
                                     betas=(args.beta1, 0.999))

    # Linear LR decay schedule after args.epoch_decay epochs.
    scheduler_G = LambdaLR(optimizer_G, lr_lambda=loss_scheduler(args).f)
    scheduler_D_A = LambdaLR(optimizer_D_A, lr_lambda=loss_scheduler(args).f)
    scheduler_D_B = LambdaLR(optimizer_D_B, lr_lambda=loss_scheduler(args).f)

    #dataset loading
    train_dataset = BasicDataset(args.data_name,
                                 args.image_size,
                                 is_train=True)
    train_loader = torch.utils.data.DataLoader(train_dataset,
                                               batch_size=args.batch_size,
                                               shuffle=True,
                                               num_workers=0)

    #######################################################################################

    #train
    total_epoch = args.epoch

    # Image pools feed the discriminators a history of generated images.
    fake_A_buffer = ImagePool()
    fake_B_buffer = ImagePool()

    for epoch in range(total_epoch):
        start = time.time()
        # Running sums: [G_A2B, G_B2A, cycle_A, cycle_B, D_A, D_B] losses.
        losses = [0 for i in range(6)]
        for i, data in enumerate(train_loader):
            #generate image: translations and their cycle reconstructions
            real_A = data['A'].to(device)
            real_B = data['B'].to(device)
            fake_A, fake_B = G_B2A(real_B), G_A2B(real_A)
            rec_A, rec_B = G_B2A(fake_B), G_A2B(fake_A)
            if args.lambda_identity > 0:
                iden_A, iden_B = G_B2A(real_A), G_A2B(real_B)

            #train generator (freeze discriminators during the G step)
            set_requires_grad([D_A, D_B], False)
            optimizer_G.zero_grad()

            pred_fake_A = D_A(fake_A)
            loss_G_B2A = adv_loss(
                pred_fake_A,
                torch.tensor(1.0).expand_as(pred_fake_A).to(device))

            pred_fake_B = D_B(fake_B)
            loss_G_A2B = adv_loss(
                pred_fake_B,
                torch.tensor(1.0).expand_as(pred_fake_B).to(device))

            loss_cycle_A = cycle_loss(rec_A, real_A)
            loss_cycle_B = cycle_loss(rec_B, real_B)

            if args.lambda_identity > 0:
                loss_identity_A = identity_loss(iden_A, real_A)
                loss_identity_B = identity_loss(iden_B, real_B)
                loss_G = loss_G_A2B + loss_G_B2A + loss_cycle_A * args.lambda_cycle + loss_cycle_B * args.lambda_cycle + loss_identity_A * args.lambda_cycle * args.lambda_identity + loss_identity_B * args.lambda_cycle * args.lambda_identity

            else:
                loss_G = loss_G_A2B + loss_G_B2A + loss_cycle_A * args.lambda_cycle + loss_cycle_B * args.lambda_cycle

            loss_G.backward()
            optimizer_G.step()

            losses[0] += loss_G_A2B.item()
            losses[1] += loss_G_B2A.item()
            losses[2] += loss_cycle_A.item()
            losses[3] += loss_cycle_B.item()

            #train discriminator (on real images and pooled fakes)
            set_requires_grad([D_A, D_B], True)
            optimizer_D_A.zero_grad()
            pred_real_A = D_A(real_A)
            fake_A_ = fake_A_buffer.get_images(fake_A)
            pred_fake_A = D_A(fake_A_.detach())
            loss_D_A_real = adv_loss(
                pred_real_A,
                torch.tensor(1.0).expand_as(pred_real_A).to(device))
            loss_D_A_fake = adv_loss(
                pred_fake_A,
                torch.tensor(0.0).expand_as(pred_fake_A).to(device))
            loss_D_A = (loss_D_A_fake + loss_D_A_real) * 0.5
            loss_D_A.backward()
            optimizer_D_A.step()

            optimizer_D_B.zero_grad()
            pred_real_B = D_B(real_B)
            fake_B_ = fake_B_buffer.get_images(fake_B)
            pred_fake_B = D_B(fake_B_.detach())
            loss_D_B_real = adv_loss(
                pred_real_B,
                torch.tensor(1.0).expand_as(pred_real_B).to(device))
            loss_D_B_fake = adv_loss(
                pred_fake_B,
                torch.tensor(0.0).expand_as(pred_fake_B).to(device))
            loss_D_B = (loss_D_B_fake + loss_D_B_real) * 0.5
            loss_D_B.backward()
            optimizer_D_B.step()

            losses[4] += loss_D_A.item()
            losses[5] += loss_D_B.item()

            #get sample (periodic grid of real/fake/reconstructed images)
            if (epoch * len(train_loader) + i) % args.sample_frequecy == 0:
                images_sample = torch.cat(
                    (real_A.data, fake_B.data, rec_A.data, real_B.data,
                     fake_A.data, rec_B.data), 0)
                if not os.path.exists("sample/" + args.model):
                    os.makedirs("sample/" + args.model)
                save_image(images_sample,
                           "sample/" + args.model + "/" +
                           str(epoch * len(train_loader) + i) + ".png",
                           nrow=3,
                           normalize=True)

            current_batch = epoch * len(train_loader) + i  # NOTE(review): unused
            # NOTE(review): "/200" below is hard-coded; args.epoch may differ.
            sys.stdout.write(
                f"\r[Epoch {epoch+1}/200] [Index {i}/{len(train_loader)}] [D_A loss: {loss_D_A.item():.4f}] [D_B loss: {loss_D_B.item():.4f}] [G loss: adv: {loss_G.item():.4f}] [lr: {scheduler_G.get_lr()}]"
            )

        #get tensorboard logs
        # NOTE(review): the SummaryWriter is re-created every epoch; creating
        # it once before the loop would avoid reopening the event file.
        if not os.path.exists(args.log_dir + args.model):
            os.makedirs(args.log_dir + args.model)
        writer = SummaryWriter(args.log_dir + args.model)
        writer.add_scalar('loss_G_A2B', losses[0] / float(len(train_loader)),
                          epoch)
        writer.add_scalar('loss_D_A', losses[4] / float(len(train_loader)),
                          epoch)
        writer.add_scalar('loss_G_B2A', losses[1] / float(len(train_loader)),
                          epoch)
        writer.add_scalar('loss_D_B', losses[5] / float(len(train_loader)),
                          epoch)
        writer.add_scalar('loss_cycle_A', losses[2] / float(len(train_loader)),
                          epoch)
        writer.add_scalar('loss_cycle_B', losses[3] / float(len(train_loader)),
                          epoch)
        writer.add_scalar('learning_rate_G', np.array(scheduler_G.get_lr()),
                          epoch)
        writer.add_scalar('learning_rate_D_A',
                          np.array(scheduler_D_A.get_lr()), epoch)
        writer.add_scalar('learning_rate_D_B',
                          np.array(scheduler_D_B.get_lr()), epoch)
        # NOTE(review): "/200" below is hard-coded; args.epoch may differ.
        sys.stdout.write(
            f"[Epoch {epoch+1}/200] [D_A loss: {losses[4]/float(len(train_loader)):.4f}] [D_B loss: {losses[5]/float(len(train_loader)):.4f}] [G adv loss: adv: {losses[0]/float(len(train_loader))+losses[1]/float(len(train_loader)):.4f}]"
        )

        #update learning rate
        scheduler_G.step()
        scheduler_D_A.step()
        scheduler_D_B.step()

        # Periodic checkpointing of all four networks.
        if (epoch + 1) % args.checkpoint_frequecy == 0:
            if not os.path.exists("models/" + args.model + "/G_A2B/"):
                os.makedirs("models/" + args.model + "/G_A2B/")
            if not os.path.exists("models/" + args.model + "/G_B2A/"):
                os.makedirs("models/" + args.model + "/G_B2A/")
            if not os.path.exists("models/" + args.model + "/D_A/"):
                os.makedirs("models/" + args.model + "/D_A/")
            if not os.path.exists("models/" + args.model + "/D_B/"):
                os.makedirs("models/" + args.model + "/D_B/")
            torch.save(
                G_A2B.state_dict(),
                "models/" + args.model + "/G_A2B/" + str(epoch) + ".pth")
            torch.save(
                G_B2A.state_dict(),
                "models/" + args.model + "/G_B2A/" + str(epoch) + ".pth")
            torch.save(D_A.state_dict(),
                       "models/" + args.model + "/D_A/" + str(epoch) + ".pth")
            torch.save(D_B.state_dict(),
                       "models/" + args.model + "/D_B/" + str(epoch) + ".pth")
コード例 #10
0
def train_net(unet, device, batch_size, epochs, lr, dir_checkpoint,
              checkpoint_name):
    """Train a segmentation network and checkpoint the best model.

    Trains ``unet`` with SGD + cross-entropy on the file lists
    ``./data/train.lst`` / ``./data/val.lst``, logging per-step loss,
    learning rate, weight/grad histograms and validation metrics to
    TensorBoard.  The checkpoint at ``dir_checkpoint + checkpoint_name``
    is overwritten whenever the running best frequency-weighted IoU
    beats the one stored in the existing checkpoint.

    Args:
        unet: model to train (caller is expected to have moved it to
            ``device`` already).
        device: torch device for input images and target masks.
        batch_size: mini-batch size for both train and eval loaders.
        epochs: number of passes over the training set.
        lr: initial learning rate for SGD.
        dir_checkpoint: checkpoint directory (original code concatenates
            it with ``checkpoint_name``, so it should end with ``/``).
        checkpoint_name: checkpoint file name inside ``dir_checkpoint``.
    """
    global_step = 0
    # One TensorBoard run per (checkpoint, lr, batch-size) combination.
    writer = SummaryWriter(comment=checkpoint_name +
                           f'LR_{lr}_BS_{batch_size}')
    # Data preprocessing happens inside BasicDataset.
    sate_dataset_train = BasicDataset("./data/train.lst")
    sate_dataset_val = BasicDataset("./data/val.lst")
    train_steps = len(sate_dataset_train)
    train_dataloader = DataLoader(sate_dataset_train,
                                  batch_size=batch_size,
                                  shuffle=True,
                                  num_workers=8)
    # drop_last: discard the final incomplete validation batch.
    eval_dataloader = DataLoader(sate_dataset_val,
                                 batch_size=batch_size,
                                 shuffle=True,
                                 num_workers=8,
                                 drop_last=True)
    criterion = nn.CrossEntropyLoss()
    # FIX: the original referenced a module-level `net` throughout; the
    # function now consistently trains the `unet` argument it was given.
    optimizer = optim.SGD(unet.parameters(),
                          lr=lr,
                          weight_decay=1e-8,
                          momentum=0.9)
    scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer,
                                                     mode='min',
                                                     patience=80,
                                                     factor=0.9,
                                                     min_lr=5e-5)
    # Kept only for the checkpoint payload; never updated below.
    epoch_val_loss = float('inf')
    fw_iou_avg = 0  # running best frequency-weighted IoU (selection metric)
    for epoch in range(epochs):
        epochs_loss = 0  # accumulated training loss for this epoch
        with tqdm(total=train_steps,
                  desc=f'Epoch {epoch + 1}/{epochs}',
                  unit='img') as pbar:
            for idx, batch_samples in enumerate(train_dataloader):
                batch_image = batch_samples["image"].to(device=device,
                                                        dtype=torch.float32)
                # CrossEntropyLoss wants class-index targets as long.
                y_true = batch_samples["mask"].to(device=device,
                                                  dtype=torch.long)
                logits = unet(batch_image)  # (N, C, H, W) class scores
                loss = criterion(logits, y_true)
                epochs_loss += loss.item()
                writer.add_scalar('Loss/train', loss.item(), global_step)
                pbar.set_postfix(**{'loss (batch)': loss.item()})

                optimizer.zero_grad()
                loss.backward()
                nn.utils.clip_grad_value_(unet.parameters(), 0.1)  # gradient clipping
                optimizer.step()
                pbar.update(batch_image.shape[0])
                global_step += 1
                scheduler.step(loss)  # plateau scheduler monitors batch loss
                writer.add_scalar('learning_rate',
                                  optimizer.param_groups[0]['lr'], global_step)
                # Roughly once per "epoch's worth" of steps: log histograms
                # and run a full validation pass.  max(1, ...) guards the
                # modulo against train_steps < batch_size.
                if global_step % max(1, train_steps // batch_size) == 0:
                    for tag, value in unet.named_parameters():
                        tag = tag.replace('.', '/')
                        writer.add_histogram('weights/' + tag,
                                             value.data.cpu().numpy(),
                                             global_step)
                        # FIX: original logged the weights again under
                        # 'grads/'; log the actual gradients instead.
                        if value.grad is not None:
                            writer.add_histogram('grads/' + tag,
                                                 value.grad.data.cpu().numpy(),
                                                 global_step)
                    val_loss, pixel_acc_avg, mean_iou_avg, _fw_iou_avg = eval_net(
                        unet, eval_dataloader, device)
                    if fw_iou_avg < _fw_iou_avg:
                        fw_iou_avg = _fw_iou_avg
                    logging.info(
                        'Validation cross entropy: {}'.format(val_loss))
                    writer.add_scalar('Loss/test', val_loss, global_step)
                    writer.add_scalar('pixel_acc_avg', pixel_acc_avg,
                                      global_step)
                    writer.add_scalar('mean_iou_avg', mean_iou_avg,
                                      global_step)
                    writer.add_scalar('fw_iou_avg', fw_iou_avg, global_step)

        # Save the model when the best fw-IoU so far beats the value stored
        # in an existing checkpoint (or unconditionally on first save).
        ckpt_path = dir_checkpoint + checkpoint_name
        state = {
            'net': unet.state_dict(),
            'epoch_val_score': epoch_val_loss,
            'fw_iou_avg': fw_iou_avg,
            'epochth': epoch + 1
        }
        if os.path.exists(ckpt_path):
            checkpoint = torch.load(ckpt_path)
            print(fw_iou_avg, checkpoint['fw_iou_avg'])
            if fw_iou_avg > checkpoint['fw_iou_avg']:
                print('save!!!!!!!!!!!!!!!!!!!!!!!!!!!!!')
                torch.save(state, ckpt_path)
                logging.info(f'checkpoint {epoch + 1} saved!')
        else:
            try:
                os.mkdir(dir_checkpoint)
                logging.info('create checkpoint directory!')
            except OSError:
                # Directory may already exist; best-effort, still save below.
                logging.info('save checkpoint error!')
            torch.save(state, ckpt_path)
            logging.info(f'checkpoint {epoch + 1} saved!')
    writer.close()
コード例 #11
0
def create_h5_all_processed(source_h5,
                            target,
                            tsv,
                            MAX_NUM_BOXES=cfg['max_box_num'],
                            WRITE_CHUNK=cfg['dataloader_cfg']['batch_size']):
    """Build a fully preprocessed HDF5 file at `target`.

    Iterates `source_h5` through a BasicDataset/DataLoader to write
    tokenized queries, box positions, and box features; then reads
    per-box class labels from the `tsv` file; finally copies the
    'others/data' dataset verbatim from `source_h5`.

    NOTE(review): `MAX_NUM_BOXES` is never used in the body — the code
    reads cfg['max_box_num'] directly everywhere; confirm whether the
    parameter was meant to override it.
    Relies on module-level `cfg` and `tokenizer`.
    """
    ds = BasicDataset(source_h5, tokenizer, cfg)
    dl = data.DataLoader(ds,
                         shuffle=False,
                         collate_fn=BasicDataset.Collate_fn,
                         **cfg['dataloader_cfg'])

    with h5py.File(target, 'w', libver='latest') as hf:

        # NOTE(review): group names are plural ('box_features',
        # 'box_labels') but the datasets below are created under singular
        # paths ('box_feature/data', 'box_label/data'), so these two
        # groups end up empty — verify downstream readers' expectations.
        hf.create_group('querys')
        hf.create_group('box_poss')
        hf.create_group('box_features')
        hf.create_group('box_labels')

        # All datasets start at WRITE_CHUNK rows and are grown with
        # resize() as batches arrive (maxshape=None on axis 0).
        querys_h5ds = hf.create_dataset(
            'querys/data',
            shape=(WRITE_CHUNK, cfg['max_query_word']),
            chunks=(1, cfg['max_query_word']),
            maxshape=(None, cfg['max_query_word']),
            #compression="lzf",
            dtype='i')
        box_poss_h5ds = hf.create_dataset(
            'box_poss/data',
            shape=(WRITE_CHUNK, cfg['max_box_num'], 5),
            chunks=(1, cfg['max_box_num'], 5),
            maxshape=(None, cfg['max_box_num'], 5),
            #compression="lzf",
            dtype='f')
        box_features_h5ds = hf.create_dataset(
            'box_feature/data',
            shape=(WRITE_CHUNK, cfg['max_box_num'], 2048),
            chunks=(1, cfg['max_box_num'], 2048),
            maxshape=(None, cfg['max_box_num'], 2048),
            #compression="lzf",
            dtype='f')
        box_labels_h5ds = hf.create_dataset(
            'box_label/data',
            shape=(WRITE_CHUNK, cfg['max_box_num']),
            chunks=(1, cfg['max_box_num']),
            maxshape=(None, cfg['max_box_num']),
            #compression="lzf",
            dtype='i')
        others_h5ds = hf.create_dataset(
            'others/data',
            shape=(WRITE_CHUNK, 5),
            chunks=(1, 5),
            maxshape=(None, 5),
            #compression="lzf",
            dtype='i')

        def flush_into_ds(hf, i, query, box_pos, box_feature):
            """Grow each dataset to `i` rows and write the current batch.

            `i` is the running total of rows written so far; the slice
            start `(i - 1) // WRITE_CHUNK * WRITE_CHUNK` recovers the
            offset of the batch that was just produced (assumes every
            batch except possibly the last has WRITE_CHUNK rows).
            """
            querys_h5ds = hf.get('querys/data')
            querys_h5ds.resize(i, axis=0)
            querys_h5ds[(i - 1) // WRITE_CHUNK * WRITE_CHUNK:i, :] = query

            box_poss_h5ds = hf.get('box_poss/data')
            box_poss_h5ds.resize(i, axis=0)
            box_poss_h5ds[(i - 1) // WRITE_CHUNK *
                          WRITE_CHUNK:i, :, :] = box_pos

            box_features_h5ds = hf.get('box_feature/data')
            box_features_h5ds.resize(i, axis=0)
            box_features_h5ds[(i - 1) // WRITE_CHUNK *
                              WRITE_CHUNK:i, :] = box_feature

            #box_labels_h5ds = hf.get('box_label/data')
            #box_labels_h5ds.resize(i, axis=0)
            #box_labels_h5ds[(i - 1) // WRITE_CHUNK * WRITE_CHUNK:i + 1, :] = box_label

            return 0

        # Pass 1: stream query/box tensors from the dataloader into HDF5.
        i = 0
        for query, box_pos, box_feature, _, _ in tqdm(dl):
            query, box_pos, box_feature = query.numpy(), box_pos.numpy(
            ), box_feature.numpy()
            i += query.shape[0]
            flush_into_ds(hf, i, query, box_pos, box_feature)
        # Labels dataset is sized to the total row count before pass 2.
        box_labels_h5ds.resize(i, axis=0)

        print('reading labels:')
        ##################

        # Pass 2: parse per-box class labels from the TSV.  The first
        # readline() skips the header row; labels are base64-encoded
        # int64 arrays in column 6, shifted by +1 (presumably reserving
        # 0 for padding — confirm against the consumer).
        with open(tsv, 'r') as f:
            l = f.readline()
            l = f.readline()
            i = 0

            while l:
                w_class_label = np.zeros((cfg['max_box_num'], ))
                l = l.strip().split('\t')
                num_boxes = int(l[3])
                class_label = np.frombuffer(base64.b64decode(l[6]),
                                            dtype=np.int64).reshape(
                                                num_boxes, ) + 1
                # Truncate or zero-pad to exactly max_box_num labels.
                if num_boxes > cfg['max_box_num']:
                    w_class_label[:] = class_label[:cfg['max_box_num']]
                else:
                    w_class_label[:num_boxes] = class_label

                box_labels_h5ds[i, :] = w_class_label
                if i % WRITE_CHUNK == WRITE_CHUNK - 1:
                    print('\rline {}'.format(i), end='')
                i += 1
                l = f.readline()
            if i % WRITE_CHUNK != WRITE_CHUNK - 1:
                print('\rline {}'.format(i), end='')
        print()

        ##############

        # Pass 3: copy 'others/data' row by row from the source file.
        print('reading others\r', end='')
        with h5py.File(source_h5, 'r', libver='latest') as h5file_source:
            others_h5ds_source = h5file_source.get('others/data')
            len_others = h5file_source.get('others/data').shape[0]
            others_h5ds.resize(len_others, axis=0)
            for i in range(len_others):
                others_h5ds[i] = others_h5ds_source[i]
        print('reading others finished!')

    return
コード例 #12
0
    # create new OrderedDict that does not contain `module.`
    from collections import OrderedDict
    new_state_dict = OrderedDict()
    for k, v in state_dict['net'].items():
        name = k[7:]  # remove `module.`
        new_state_dict[name] = v
    # load params
    model.load_state_dict(new_state_dict)
    return model


correct_ratio = []
alpha = 0.5  # presumably a distillation mixing weight — confirm against usage below
batch_size = 32

# Load the train list; preprocessing is handled inside BasicDataset.
sate_dataset_train = BasicDataset("./data/train.lst")
train_steps = len(sate_dataset_train)
sate_dataset_val = BasicDataset("./data/val.lst")
# Wrap the training set in a DataLoader.
train_dataloader = DataLoader(sate_dataset_train,
                              batch_size=batch_size,
                              shuffle=True,
                              num_workers=8)
# Wrap the validation set; drop_last discards the final incomplete batch.
eval_dataloader = DataLoader(sate_dataset_val,
                             batch_size=batch_size,
                             shuffle=True,
                             num_workers=8,
                             drop_last=True)

device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
# Teacher network for (presumably) knowledge distillation — confirm.
teach_model = deeplabv3plus_resnet50(num_classes=8, output_stride=16)
teach_model.to(device=device)