コード例 #1
0
ファイル: train.py プロジェクト: dvschultz/imaginaire
def main():
    r"""Train a model described by a config file.

    Parses command-line arguments, optionally initialises distributed data
    parallel, creates the logging directory, builds the train/val data
    loaders, networks, optimizers, schedulers and trainer, restores a
    checkpoint, and then runs the epoch/iteration loop. Training stops after
    ``cfg.max_epoch`` epochs or as soon as ``cfg.max_iter`` iterations have
    been run, whichever comes first.
    """
    args = parse_args()
    set_affinity(args.local_rank)
    set_random_seed(args.seed, by_rank=True)
    cfg = Config(args.config)

    # If args.single_gpu is set to True,
    # we will disable distributed data parallel.
    if not args.single_gpu:
        cfg.local_rank = args.local_rank
        init_dist(cfg.local_rank)

    # Override the number of data loading workers if necessary.
    if args.num_workers is not None:
        cfg.data.num_workers = args.num_workers

    # Create log directory for storing training results.
    cfg.date_uid, cfg.logdir = init_logging(args.config, args.logdir)
    make_logging_dir(cfg.logdir)

    # Initialize cudnn.
    init_cudnn(cfg.cudnn.deterministic, cfg.cudnn.benchmark)

    # Initialize data loaders and models.
    train_data_loader, val_data_loader = get_train_and_val_dataloader(cfg)
    net_G, net_D, opt_G, opt_D, sch_G, sch_D = \
        get_model_optimizer_and_scheduler(cfg, seed=args.seed)
    trainer = get_trainer(cfg, net_G, net_D,
                          opt_G, opt_D,
                          sch_G, sch_D,
                          train_data_loader, val_data_loader)
    # The starting epoch/iteration counters come from the checkpoint loader.
    current_epoch, current_iteration = trainer.load_checkpoint(
        cfg, args.checkpoint, resume=args.resume)

    # Start training.
    for epoch in range(current_epoch, cfg.max_epoch):
        print('Epoch {} ...'.format(epoch))
        if not args.single_gpu:
            # NOTE(review): presumably a DistributedSampler whose shuffling
            # depends on the epoch number -- confirm.
            train_data_loader.sampler.set_epoch(current_epoch)
        trainer.start_of_epoch(current_epoch)
        for data in train_data_loader:
            data = trainer.start_of_iteration(data, current_iteration)

            # Run the configured number of discriminator updates, then the
            # configured number of generator updates, on the same batch.
            for _ in range(cfg.trainer.dis_step):
                trainer.dis_update(data)
            for _ in range(cfg.trainer.gen_step):
                trainer.gen_update(data)

            current_iteration += 1
            trainer.end_of_iteration(data, current_epoch, current_iteration)
            if current_iteration >= cfg.max_iter:
                print('Done with training!!!')
                return

        current_epoch += 1
        # NOTE(review): ``data`` is the last batch seen; this raises
        # UnboundLocalError if the loader has never yielded a batch.
        trainer.end_of_epoch(data, current_epoch, current_iteration)
    print('Done with training!!!')
    return
コード例 #2
0
def main():
    r"""Run inference with a trained model.

    Builds the test data loader, networks and trainer from the config given
    on the command line, restores ``args.checkpoint`` and then calls
    ``trainer.test`` with the test loader, ``args.output_dir`` and
    ``cfg.inference_args``.
    """
    args = parse_args()
    set_affinity(args.local_rank)
    set_random_seed(args.seed, by_rank=True)

    cfg = Config(args.config)
    # Configs without an inference section get an explicit None so that the
    # trainer.test call at the end always has a value to pass.
    if not hasattr(cfg, 'inference_args'):
        cfg.inference_args = None

    # Distributed data parallel is skipped entirely in single-GPU mode.
    if not args.single_gpu:
        cfg.local_rank = args.local_rank
        init_dist(cfg.local_rank)

    # Command-line override for the number of data-loading workers.
    if args.num_workers is not None:
        cfg.data.num_workers = args.num_workers

    # Compute the log directory names (make_logging_dir is not called here).
    cfg.date_uid, cfg.logdir = init_logging(args.config, args.logdir)

    init_cudnn(cfg.cudnn.deterministic, cfg.cudnn.benchmark)

    # Only a test loader is built; the trainer receives None as its
    # training loader.
    test_data_loader = get_test_dataloader(cfg)
    (net_G, net_D, opt_G, opt_D,
     sch_G, sch_D) = get_model_optimizer_and_scheduler(cfg, seed=args.seed)
    trainer = get_trainer(cfg, net_G, net_D, opt_G, opt_D, sch_G, sch_D,
                          None, test_data_loader)

    # Disabled: automatic download of pretrained weights when no checkpoint
    # is given on the command line.
    # if args.checkpoint == '':
    #     # Download pretrained weights.
    #     pretrained_weight_url = cfg.pretrained_weight
    #     if pretrained_weight_url == '':
    #         print('google link to the pretrained weight is not specified.')
    #         raise
    #     default_checkpoint_path = args.config.replace('.yaml', '.pt')
    #     args.checkpoint = get_checkpoint(
    #         default_checkpoint_path, pretrained_weight_url)
    #     print('Checkpoint downloaded to', args.checkpoint)

    trainer.load_checkpoint(cfg, args.checkpoint)

    # Both counters are set to -1 before running inference.
    trainer.current_epoch = trainer.current_iteration = -1
    trainer.test(test_data_loader, args.output_dir, cfg.inference_args)
コード例 #3
0
def main():
    r"""Evaluate every checkpoint found in ``args.checkpoint_logdir``.

    For each ``*.pt`` file, taken in sorted filename order, the checkpoint is
    loaded with ``resume=True``. The trainer's epoch/iteration counters are
    set from it, and ``trainer.write_metrics()`` is called.
    """
    args = parse_args()
    set_affinity(args.local_rank)
    set_random_seed(args.seed, by_rank=True)
    cfg = Config(args.config)

    if not args.single_gpu:
        # Multi-GPU run: set up distributed data parallel.
        cfg.local_rank = args.local_rank
        init_dist(cfg.local_rank)

    if args.num_workers is not None:
        # Command-line override for the number of data-loading workers.
        cfg.data.num_workers = args.num_workers

    # Logging directory for the evaluation results.
    cfg.date_uid, cfg.logdir = init_logging(args.config, args.logdir)
    make_logging_dir(cfg.logdir)

    init_cudnn(cfg.cudnn.deterministic, cfg.cudnn.benchmark)

    # Data loaders, networks, optimizers, schedulers and trainer.
    train_data_loader, val_data_loader = get_train_and_val_dataloader(cfg)
    (net_G, net_D, opt_G, opt_D,
     sch_G, sch_D) = get_model_optimizer_and_scheduler(cfg, seed=args.seed)
    trainer = get_trainer(cfg, net_G, net_D, opt_G, opt_D, sch_G, sch_D,
                          train_data_loader, val_data_loader)

    # Lexicographic order; presumably checkpoint filenames are zero-padded
    # so this is also chronological -- confirm.
    pattern = '{}/*.pt'.format(args.checkpoint_logdir)
    for ckpt_path in sorted(glob.glob(pattern)):
        trainer.current_epoch, trainer.current_iteration = \
            trainer.load_checkpoint(cfg, ckpt_path, resume=True)
        trainer.write_metrics()
    print('Done with evaluation!!!')
コード例 #4
0
ファイル: wc_vid2vid.py プロジェクト: zhangsdly/imaginaire
    def _init_single_image_model(self, load_weights=True):
        r"""Build the optional single-image model and cache its generator.

        Does nothing if ``self.single_image_model`` is already set, or if
        ``self.gen_cfg`` has no ``single_image_model`` entry. Otherwise:
        - loads the single-image config;
        - creates its networks, optimizers, schedulers and a trainer;
        - if ``load_weights`` is true, restores the checkpoint named by
          ``self.gen_cfg.single_image_model.checkpoint``.

        Args:
            load_weights (bool): Whether to restore the checkpoint.
        """
        already_built = self.single_image_model is not None
        if already_built or not hasattr(self.gen_cfg, 'single_image_model'):
            return

        print('Using single image model...')
        image_model_cfg = self.gen_cfg.single_image_model
        single_image_cfg = Config(image_model_cfg.config)

        # Networks, optimizers and schedulers for the single-image model.
        (net_G, net_D, opt_G, opt_D,
         sch_G, sch_D) = get_model_optimizer_and_scheduler(single_image_cfg)

        # Within this method the trainer is only used to restore weights;
        # no data loaders are attached.
        trainer = get_trainer(single_image_cfg, net_G, net_D, opt_G, opt_D,
                              sch_G, sch_D, None, None)
        if load_weights:
            print('Loading single image model checkpoint')
            trainer.load_checkpoint(single_image_cfg,
                                    image_model_cfg.checkpoint)
            print('Loaded single image model checkpoint')

        # Keep the inner module; net_G is presumably wrapped (e.g. in
        # DataParallel/DDP), hence ``.module`` -- confirm.
        self.single_image_model = net_G.module
        self.single_image_model_z = None
コード例 #5
0
ファイル: cagan2.py プロジェクト: zebincai/imaginaire
        # encoder
        x_en2 = self.layer1(x)
        x_en2 = torch.cat([x_en2, x_d02], dim=1)
        x_en4 = self.layer2(x_en2)
        x_en4 = torch.cat([x_en4, x_d04], dim=1)
        x_en8 = self.layer3(x_en4)
        x_en8 = torch.cat([x_en8, x_d08], dim=1)
        x_en16 = self.layer4(x_en8)
        x_en16 = torch.cat([x_en16, x_d16], dim=1)

        # decoder
        x_de8 = self.layer5(x_en16, x_en8)
        # x_de8 = torch.cat([x_de8, x_en8], dim=1)
        x_de4 = self.layer6(x_de8, x_en4)
        # x_de4 = torch.cat([x_de4, x_en4], dim=1)
        x_de2 = self.layer7(x_de4, x_en2)
        # x_de2 = torch.cat([x_de2, x_en2], dim=1)
        out = self.layer8(x_de2, xi_yj)
        out = self.outlayer(out)
        return out


if __name__ == "__main__":
    # Smoke test: build the generator from a config and print the output
    # shape for one random batch.
    from imaginaire.config import Config

    config_path = "/configs/projects/cagan/LipMPV/base_dis2_gen1.yaml"
    cfg = Config(config_path)
    gen = Generator(cfg.gen, cfg.data)
    dummy_batch = torch.randn((8, 9, 256, 192))
    prediction = gen(dummy_batch)
    print(prediction.shape)
コード例 #6
0
        """
        batch_size = images.size(0)
        features = self.model(images)
        # outputs = self.classifier(features.view(batch_size, -1))
        # return outputs, features, images
        return features


class Discriminator(nn.Module):
    r"""A 6-channel ``ResDiscriminator`` whose output is passed through a
    sigmoid.

    Args:
        gen_cfg: Accepted but unused (callers pass the discriminator config
            here).
        data_cfg: Accepted but unused.
    """

    def __init__(self, gen_cfg, data_cfg):
        super().__init__()
        # Attribute names are kept as-is: they are part of the state_dict
        # keys used by checkpoints.
        self.dis = ResDiscriminator(image_channels=6)
        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        r"""Return ``sigmoid(self.dis(x))``."""
        return self.sigmoid(self.dis(x))


if __name__ == "__main__":
    # Smoke test: build the discriminator from a config and run one random
    # batch through it.
    from imaginaire.config import Config
    cfg = Config(
        "D:/workspace/develop/imaginaire/configs/projects/cagan/LipMPV/base.yaml"
    )
    dis = Discriminator(cfg.dis, cfg.data)
    # 6 input channels to match ResDiscriminator(image_channels=6).
    batch = torch.randn((8, 6, 256, 192))

    # Discriminator.forward returns a single tensor (the sigmoid output),
    # not a (features, images) pair, so it must not be unpacked.
    # No gradients are needed for a shape check.
    with torch.no_grad():
        out = dis(batch)
    print(out.shape)
コード例 #7
0
ファイル: build_lmdb.py プロジェクト: yejees/ObjectSwap
def main():
    r""" Build lmdb for training/testing.
    Usage:
    python scripts/build_lmdb.py \
      --config configs/data_image.yaml \
      --data_root /mnt/bigdata01/datasets/test_image \
      --output_root /mnt/bigdata01/datasets/test_image/lmdb_0/ \
      --overwrite

    Output layout:
    - one LMDB per entry of ``cfg.data.data_types``, written to
      ``<output_root>/<data_type>``;
    - ``all_filenames.json`` and ``metadata.json``, written to
      ``<output_root>``.

    Returns ``None``. This includes the early exit when ``output_root``
    already exists and ``--overwrite`` is not given.
    """
    args = parse_args()
    cfg = Config(args.config)

    # Check if output file already exists.
    if os.path.exists(args.output_root):
        if args.overwrite:
            print('Deleting existing output LMDB.')
            shutil.rmtree(args.output_root)
        else:
            print('Output root LMDB already exists. Use --overwrite. ' +
                  'Exiting...')
            return

    all_filenames, extensions = \
        create_metadata(data_root=args.data_root,
                        cfg=cfg,
                        paired=args.paired,
                        input_list=args.input_list)
    required_data_types = cfg.data.data_types

    # Build LMDB.
    os.makedirs(args.output_root)
    for data_type in required_data_types:
        data_size = 0
        print('Data type:', data_type)
        filepaths, keys = [], []
        print('>> Building file list.')

        # Paired data shares one sequence -> filenames mapping across all
        # data types; unpaired data has a separate mapping per data type.
        if args.paired:
            filenames = all_filenames
        else:
            filenames = all_filenames[data_type]

        for sequence in tqdm(filenames):
            # Iterate over a snapshot: entries may be removed below.
            for filename in list(filenames[sequence]):
                filepath = construct_file_path(args.data_root, data_type,
                                               sequence, filename,
                                               extensions[data_type])
                key = '%s/%s' % (sequence, filename)
                filesize = check_and_add(filepath,
                                         key,
                                         filepaths,
                                         keys,
                                         remove_missing=args.remove_missing)

                # A missing file is reported as -1. It must not be counted
                # towards the reserved LMDB size.
                if filesize == -1:
                    # Remove file from list, if missing.
                    if args.paired and args.remove_missing:
                        print('Removing %s from list' % (filename))
                        filenames[sequence].remove(filename)
                    continue
                data_size += filesize

        # Remove empty sequences (iterate over a snapshot of the keys).
        if args.paired and args.remove_missing:
            for sequence in list(all_filenames):
                if not all_filenames[sequence]:
                    all_filenames.pop(sequence)

        # Reserve (1 + metadata_factor) times the payload, but at least 1 GB.
        # Keep the floor an int so build_lmdb never gets the float 1e9.
        data_size = max(int((1 + args.metadata_factor) * data_size),
                        int(1e9))
        print('Reserved size: %s, %dGB' % (data_type, data_size // 1e9))

        # Write LMDB to file.
        output_filepath = os.path.join(args.output_root, data_type)
        build_lmdb(filepaths, keys, output_filepath, data_size, args.large)

    # Output list of all filenames.
    # NOTE(review): the else branch looks unreachable -- an empty/None
    # output_root already fails at os.path.exists / os.makedirs above.
    if args.output_root:
        with open(os.path.join(args.output_root, 'all_filenames.json'),
                  'w') as fout:
            json.dump(all_filenames, fout, indent=4)

        # Output metadata.
        with open(os.path.join(args.output_root, 'metadata.json'),
                  'w') as fout:
            json.dump(extensions, fout, indent=4)
    else:
        return all_filenames, extensions
コード例 #8
0
def main():
    r"""Build one LMDB per configured data type plus JSON metadata.

    Output layout:
    - one LMDB per entry of ``cfg.data.data_types``, written to
      ``<output_root>/<data_type>``;
    - ``all_filenames.json`` and ``metadata.json``, written to
      ``<output_root>``.

    Returns ``None``. This includes the early exit when ``output_root``
    already exists and ``--overwrite`` is not given.
    """
    args = parse_args()
    cfg = Config(args.config)

    # Check if output file already exists.
    if os.path.exists(args.output_root):
        if args.overwrite:
            print('Deleting existing output LMDB.')
            shutil.rmtree(args.output_root)
        else:
            print('Output root LMDB already exists. Use --overwrite. ' +
                  'Exiting...')
            return

    # all_filenames: dictionary
    #   "images_content" -> "class01": ["image01.jpg",...], "class02": ["image01.jpg",...]
    #   "images_style"   -> "class01": ["image01.jpg",...], "class02": ["image01.jpg",...]
    all_filenames, extensions = \
        create_metadata(data_root=args.data_root,
                        cfg=cfg,
                        paired=args.paired,
                        input_list=args.input_list)
    required_data_types = cfg.data.data_types

    # Build LMDB.
    os.makedirs(args.output_root)
    # e.g. required_data_types = ['images_content', 'images_style']
    for data_type in required_data_types:
        data_size = 0
        print('Data type:', data_type)
        filepaths, keys = [], []
        print('>> Building file list.')

        # Paired data shares one sequence -> filenames mapping across all
        # data types; unpaired data has a separate mapping per data type.
        if args.paired:
            filenames = all_filenames
        else:
            filenames = all_filenames[data_type]

        for sequence in tqdm(filenames):  # each class
            # Append key and filepath for each file in this class. Iterate
            # over a snapshot: entries may be removed below.
            for filename in list(filenames[sequence]):
                filepath = construct_file_path(args.data_root, data_type,
                                               sequence, filename,
                                               extensions[data_type])
                # NOTE(review): os.path.join gives '\\'-separated keys on
                # Windows, unlike the '%s/%s' form -- confirm which form
                # the dataset reader expects.
                key = os.path.join(sequence, filename)
                filesize = check_and_add(filepath,
                                         key,
                                         filepaths,
                                         keys,
                                         remove_missing=args.remove_missing)

                # A missing file is reported as -1. It must not be counted
                # towards the reserved LMDB size.
                if filesize == -1:
                    # Remove file from list, if missing.
                    if args.paired and args.remove_missing:
                        print('Removing %s from list' % (filename))
                        filenames[sequence].remove(filename)
                    continue
                data_size += filesize

        # Remove empty sequences (iterate over a snapshot of the keys).
        if args.paired and args.remove_missing:
            for sequence in list(all_filenames):
                if not all_filenames[sequence]:
                    all_filenames.pop(sequence)

        # Reserve (1 + metadata_factor) times the payload, but at least 1 GB.
        # Keep the floor an int so build_lmdb never gets the float 1e9.
        data_size = max(int((1 + args.metadata_factor) * data_size),
                        int(1e9))
        print('Reserved size: %s, %dGB' % (data_type, data_size // 1e9))

        # Write LMDB to file.
        output_filepath = os.path.join(args.output_root, data_type)
        build_lmdb(filepaths, keys, output_filepath, data_size, args.large)

    # Output list of all filenames.
    # NOTE(review): the else branch looks unreachable -- an empty/None
    # output_root already fails at os.path.exists / os.makedirs above.
    if args.output_root:
        with open(os.path.join(args.output_root, 'all_filenames.json'),
                  'w') as fout:
            json.dump(all_filenames, fout, indent=4)

        # Output metadata.
        with open(os.path.join(args.output_root, 'metadata.json'),
                  'w') as fout:
            json.dump(extensions, fout, indent=4)
    else:
        return all_filenames, extensions