Example #1
def main(args):
    args.distributed = dist.get_world_size() > 1

    transform = transforms.Compose(
        [
            transforms.Resize(args.size),
            transforms.CenterCrop(args.size),
            transforms.ToTensor(),
            transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5]),
        ]
    )

    dataset = datasets.ImageFolder(args.path, transform=transform)
    # Shard the dataset across ranks when running multi-GPU.
    sampler = dist.data_sampler(dataset, shuffle=True, distributed=args.distributed)
    loader = DataLoader(
        dataset, batch_size=128 // args.n_gpu, sampler=sampler, num_workers=2
    )

    model = load_model(args.checkpoint).to(DEVICE)

    if args.distributed:
        model = nn.parallel.DistributedDataParallel(
            model,
            device_ids=[dist.get_local_rank()],
            output_device=dist.get_local_rank(),
        )

    evaluate(loader, model, args.out_path, args.sample_size)
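All of these examples shard their data through dist.data_sampler. A minimal sketch of that helper, matching the distributed.py utility in rosinality/vq-vae-2-pytorch (treat it as an assumption if your dist module differs):

from torch.utils import data


def data_sampler(dataset, shuffle, distributed):
    # Distributed runs: give each rank a disjoint shard of the dataset.
    if distributed:
        return data.distributed.DistributedSampler(dataset, shuffle=shuffle)

    if shuffle:
        return data.RandomSampler(dataset)

    return data.SequentialSampler(dataset)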
Example #2
def main(args):
    device = "cuda"

    args.distributed = dist.get_world_size() > 1

    transform = transforms.Compose([
        transforms.Resize(args.size),
        transforms.ToTensor(),
        # Map single-channel pixel values from [0, 1] to [-1, 1].
        transforms.Normalize([0.5], [0.5]),
    ])

    txt_path = 'data/train.txt'
    images_path = '/data'
    labels_path = '/data'

    dataset = txtDataset(txt_path,
                         images_path,
                         labels_path,
                         transform=transform)

    sampler = dist.data_sampler(dataset,
                                shuffle=True,
                                distributed=args.distributed)
    # batch_size is a module-level constant (not shown in this excerpt),
    # split evenly across the available GPUs.
    loader = DataLoader(dataset,
                        batch_size=batch_size // args.n_gpu,
                        sampler=sampler,
                        num_workers=16)

    model = VQVAE().to(device)

    if args.distributed:
        model = nn.parallel.DistributedDataParallel(
            model,
            device_ids=[dist.get_local_rank()],
            output_device=dist.get_local_rank(),
        )

    optimizer = optim.Adam(model.parameters(), lr=args.lr)
    scheduler = None
    if args.sched == "cycle":
        scheduler = CycleScheduler(
            optimizer,
            args.lr,
            n_iter=len(loader) * args.epoch,
            momentum=None,
            warmup_proportion=0.05,
        )

    for i in range(args.epoch):
        train(i, loader, model, optimizer, scheduler, device)

        if dist.is_primary():
            torch.save(model.state_dict(),
                       f"checkpoint/vqvae_{str(i + 1).zfill(3)}.pt")
Example #3
def main(args):
    device = "cuda"

    args.distributed = dist.get_world_size() > 1

    transform = video_transforms.Compose([
        RandomSelectFrames(16),
        video_transforms.Resize(args.size),
        video_transforms.CenterCrop(args.size),
        volume_transforms.ClipToTensor(),
        tensor_transforms.Normalize(0.5, 0.5)
    ])

    # Load the list of training clips with at least 16 frames.
    with open(
            '/home/shirakawa/movie/code/iVideoGAN/over16frame_list_training.txt',
            'rb') as f:
        train_file_list = pickle.load(f)
    print(len(train_file_list))

    dataset = MITDataset(train_file_list, transform=transform)
    sampler = dist.data_sampler(dataset,
                                shuffle=True,
                                distributed=args.distributed)
    loader = DataLoader(dataset,
                        batch_size=32 // args.n_gpu,
                        sampler=sampler,
                        num_workers=2)

    model = VQVAE().to(device)

    if args.distributed:
        model = nn.parallel.DistributedDataParallel(
            model,
            device_ids=[dist.get_local_rank()],
            output_device=dist.get_local_rank(),
        )

    optimizer = optim.Adam(model.parameters(), lr=args.lr)
    scheduler = None
    if args.sched == "cycle":
        scheduler = CycleScheduler(
            optimizer,
            args.lr,
            n_iter=len(loader) * args.epoch,
            momentum=None,
            warmup_proportion=0.05,
        )

    for i in range(args.epoch):
        train(i, loader, model, optimizer, scheduler, device)

        if dist.is_primary():
            torch.save(model.state_dict(),
                       f"checkpoint_vid_v2/vqvae_{str(i + 1).zfill(3)}.pt")
Example #4
def main(args):
    device = "cuda"

    args.distributed = dist.get_world_size() > 1

    transform = transforms.Compose([
        transforms.Resize(args.size),
        transforms.CenterCrop(args.size),
        transforms.ToTensor(),
        transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5]),
    ])

    dataset = datasets.ImageFolder(args.path, transform=transform)
    sampler = dist.data_sampler(dataset,
                                shuffle=True,
                                distributed=args.distributed)
    loader = DataLoader(dataset,
                        batch_size=128 // args.n_gpu,
                        sampler=sampler,
                        num_workers=2)

    model = VQVAE().to(device)

    if args.load_path:
        # Warm-start from an existing checkpoint when one is given.
        load_state_dict = torch.load(args.load_path, map_location=device)
        model.load_state_dict(load_state_dict)
        print('successfully loaded model')

    if args.distributed:
        model = nn.parallel.DistributedDataParallel(
            model,
            device_ids=[dist.get_local_rank()],
            output_device=dist.get_local_rank(),
        )

    optimizer = optim.Adam(model.parameters(), lr=args.lr)
    scheduler = None
    if args.sched == "cycle":
        scheduler = CycleScheduler(
            optimizer,
            args.lr,
            n_iter=len(loader) * args.epoch,
            momentum=None,
            warmup_proportion=0.05,
        )

    for i in range(args.epoch):
        train(i, loader, model, optimizer, scheduler, device)

        if dist.is_primary():
            torch.save(model.state_dict(),
                       f"checkpoint/vqvae_{str(i + 1).zfill(3)}.pt")
Example #5
def main(args):
    device = "cuda"

    args.distributed = dist.get_world_size() > 1

    transform = transforms.Compose([
        transforms.Resize(args.size),
        transforms.ToTensor(),
        # Map single-channel pixel values from [0, 1] to [-1, 1].
        transforms.Normalize([0.5], [0.5]),
    ])

    txt_path = './data/train.txt'
    images_path = './data'
    labels_path = './data'

    dataset = txtDataset(txt_path,
                         images_path,
                         labels_path,
                         transform=transform)

    sampler = dist.data_sampler(dataset,
                                shuffle=True,
                                distributed=args.distributed)
    # batch_size is a module-level constant (not shown in this excerpt),
    # split evenly across the available GPUs.
    loader = DataLoader(dataset,
                        batch_size=batch_size // args.n_gpu,
                        sampler=sampler,
                        num_workers=16)

    # Initialize generator and discriminator
    DpretrainedPath = './checkpoint/vqvae2GAN_040.pt'
    GpretrainedPath = './checkpoint/vqvae_040.pt'

    discriminator = Discriminator()
    generator = Generator()
    if os.path.exists(DpretrainedPath):
        print('Loading model weights...')
        discriminator.load_state_dict(
            torch.load(DpretrainedPath)['discriminator'])
        print('done')
    if os.path.exists(GpretrainedPath):
        print('Loading model weights...')
        generator.load_state_dict(torch.load(GpretrainedPath))
        print('done')

    discriminator = discriminator.to(device)
    generator = generator.to(device)

    if args.distributed:
        discriminator = nn.parallel.DistributedDataParallel(
            discriminator,
            device_ids=[dist.get_local_rank()],
            output_device=dist.get_local_rank(),
        )

    if args.distributed:
        generator = nn.parallel.DistributedDataParallel(
            generator,
            device_ids=[dist.get_local_rank()],
            output_device=dist.get_local_rank(),
        )

    optimizer_D = optim.Adam(discriminator.parameters(), lr=args.lr)
    optimizer_G = optim.Adam(generator.parameters(), lr=args.lr)
    scheduler_D = None
    scheduler_G = None
    if args.sched == "cycle":
        scheduler_D = CycleScheduler(
            optimizer_D,
            args.lr,
            n_iter=len(loader) * args.epoch,
            momentum=None,
            warmup_proportion=0.05,
        )

        scheduler_G = CycleScheduler(
            optimizer_G,
            args.lr,
            n_iter=len(loader) * args.epoch,
            momentum=None,
            warmup_proportion=0.05,
        )

    # Resume at epoch 41 to line up with the epoch-40 checkpoints loaded above.
    for i in range(41, args.epoch):
        train(i, loader, discriminator, generator, scheduler_D, scheduler_G,
              optimizer_D, optimizer_G, device)

        if dist.is_primary():
            torch.save(
                {
                    'generator': generator.state_dict(),
                    'discriminator': discriminator.state_dict(),
                    'g_optimizer': optimizer_G.state_dict(),
                    'd_optimizer': optimizer_D.state_dict(),
                },
                f'checkpoint/vqvae2GAN_{str(i + 1).zfill(3)}.pt',
            )
            # n_critic is a module-level constant (not shown in this excerpt);
            # see the placeholder sketch after this example.
            if (i + 1) % n_critic == 0:
                torch.save(generator.state_dict(),
                           f"checkpoint/vqvae_{str(i + 1).zfill(3)}.pt")
Example #6
def main(args):
    device = "cuda"

    args.distributed = dist.get_world_size() > 1

    transform = transforms.Compose([
        transforms.Resize(args.size),
        transforms.CenterCrop(args.size),
        transforms.ToTensor(),
        transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5]),
    ])

    dataset = OffsetDataset(args.path, transform=transform, offset=args.offset)
    sampler = dist.data_sampler(dataset,
                                shuffle=True,
                                distributed=args.distributed)
    loader = DataLoader(dataset,
                        batch_size=args.bsize // args.n_gpu,
                        sampler=sampler,
                        num_workers=2)

    # Load pre-trained VQVAE
    vqvae = VQVAE().to(device)
    try:
        vqvae.load_state_dict(torch.load(args.ckpt))
    except RuntimeError:
        # The checkpoint was probably saved from a DataParallel/DDP model;
        # strip the 'module.' prefix from every key and retry.
        print("Seems the checkpoint was trained with data parallel, "
              "try loading it that way")
        weights = torch.load(args.ckpt)
        weights = {key.replace('module.', '', 1): value
                   for key, value in weights.items()}
        vqvae.load_state_dict(weights)

    # Init offset encoder
    model = OffsetNetwork(vqvae).to(device)

    if args.distributed:
        model = nn.parallel.DistributedDataParallel(
            model,
            device_ids=[dist.get_local_rank()],
            output_device=dist.get_local_rank(),
            find_unused_parameters=True)

    optimizer = optim.Adam(model.parameters(), lr=args.lr)
    scheduler = None
    if args.sched == "cycle":
        scheduler = CycleScheduler(
            optimizer,
            args.lr,
            n_iter=len(loader) * args.epoch,
            momentum=None,
            warmup_proportion=0.05,
        )

    for i in range(args.epoch):
        train(i, loader, model, optimizer, scheduler, device)

        if dist.is_primary():
            torch.save(model.state_dict(),
                       f"checkpoint/offset_enc_{str(i + 1).zfill(3)}.pt")
Example #7
def main(args):
    model_config_json = open(args.config_path).read()
    print("ModelConfig:", model_config_json, file=sys.stderr, flush=True)
    model_config = VqvaeConfig.from_json(model_config_json)

    args.distributed = dist.get_world_size() > 1
    if args.device == "cpu":
        device = torch.device("cpu")
    else:
        if args.distributed:
            device = torch.device("cuda", dist.get_local_rank())
            torch.cuda.set_device(device)
            print("dist: {} {}".format(dist.get_local_rank(), dist.get_rank()),
                  device,
                  file=sys.stderr,
                  flush=True)
        else:
            device = torch.device("cuda")

    transform = build_transform(args.size)

    dataset = ImageLmdbDataset(args.img_root_path, args.img_keys_path,
                               transform, args.batch_size, args.distributed,
                               int(time.time()))
    local_batch_size = args.batch_size
    if args.distributed:
        local_batch_size = local_batch_size // dist.get_world_size()
    print("local_batch_size={}".format(local_batch_size),
          file=sys.stderr,
          flush=True)
    loader = IterDataLoader(dataset,
                            batch_size=local_batch_size,
                            num_workers=1,
                            pin_memory=True)

    model = VQVAE(model_config).to(device)

    trained_steps, recent_ckpt = find_recent_checkpoint(args.output_path)
    if recent_ckpt is not None:
        model.load_state_dict(
            torch.load(os.path.join(recent_ckpt, MODEL_SAVE_NAME),
                       map_location=device))
        print("load ckpt {}".format(recent_ckpt), file=sys.stderr, flush=True)
    optimizer = optim.Adam(model.parameters(), lr=args.lr)
    scheduler = None
    if args.sched == "cycle":
        scheduler = optim.lr_scheduler.CyclicLR(optimizer=optimizer,
                                                base_lr=args.min_lr,
                                                max_lr=args.lr,
                                                step_size_up=args.cycle_step,
                                                cycle_momentum=False)
    if recent_ckpt is not None:
        if os.path.isfile(os.path.join(recent_ckpt, OPTIMIZER_SAVE_NAME)):
            optimizer.load_state_dict(
                torch.load(os.path.join(recent_ckpt, OPTIMIZER_SAVE_NAME),
                           map_location=device))
        if os.path.isfile(os.path.join(recent_ckpt, SCHEDULER_SAVE_NAME)):
            scheduler.load_state_dict(
                torch.load(os.path.join(recent_ckpt, SCHEDULER_SAVE_NAME),
                           map_location=device))

    if args.fp16:
        if not is_apex_available():
            raise ImportError(
                "Please install apex from https://www.github.com/nvidia/apex to use fp16 training."
            )
        model, optimizer = amp.initialize(model,
                                          optimizer,
                                          opt_level=args.fp16_opt_level)

    if args.distributed:
        model = nn.parallel.DistributedDataParallel(
            model,
            device_ids=[dist.get_local_rank()],
            output_device=dist.get_local_rank(),
        )

    trainer = Trainer(args, loader, model_config, model, optimizer, scheduler,
                      device, trained_steps)
    trainer.train()
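find_recent_checkpoint and the *_SAVE_NAME constants are defined elsewhere in this script. A hypothetical sketch of the helper, assuming checkpoints are written to step-numbered subdirectories (the real layout may differ):

import os


def find_recent_checkpoint(output_path):
    # Return (trained_steps, checkpoint_dir) for the newest checkpoint,
    # or (0, None) when nothing has been saved yet.
    if not os.path.isdir(output_path):
        return 0, None

    steps = []
    for name in os.listdir(output_path):
        if name.startswith('checkpoint-'):
            try:
                steps.append(int(name.split('-')[-1]))
            except ValueError:
                continue

    if not steps:
        return 0, None

    latest = max(steps)
    return latest, os.path.join(output_path, 'checkpoint-{}'.format(latest))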
Example #8
    def run(self, args):
        device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        transform = [transforms.ToTensor()]

        if args.normalize:
            transform.append(
                transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5]))

        transform = transforms.Compose(transform)

        dataset = datasets.ImageFolder(args.path, transform=transform)
        sampler = dist_fn.data_sampler(dataset,
                                       shuffle=True,
                                       distributed=args.distributed)
        loader = DataLoader(dataset,
                            batch_size=args.batch_size // args.n_gpu,
                            sampler=sampler,
                            num_workers=args.num_workers)

        # `self` may be rebound below to a DistributedDataParallel wrapper,
        # so the rest of the method only ever touches one name.
        self = self.to(device)

        if args.distributed:
            self = nn.parallel.DistributedDataParallel(
                self,
                device_ids=[dist_fn.get_local_rank()],
                output_device=dist_fn.get_local_rank())

        optimizer = args.optimizer(self.parameters(), lr=args.lr)
        scheduler = None
        if args.sched == 'cycle':
            scheduler = CycleScheduler(
                optimizer,
                args.lr,
                n_iter=len(loader) * args.epoch,
                momentum=None,
                warmup_proportion=0.05,
            )

        # Create a timestamped run directory with sample/ and checkpoint/.
        start = str(time())
        run_path = os.path.join('runs', start)
        sample_path = os.path.join(run_path, 'sample')
        checkpoint_path = os.path.join(run_path, 'checkpoint')
        # makedirs also creates 'runs' and run_path when they do not exist.
        os.makedirs(sample_path)
        os.makedirs(checkpoint_path)

        with Progress() as progress:
            train = progress.add_task(f'epoch 1/{args.epoch}',
                                      total=args.epoch,
                                      columns='epochs')
            steps = progress.add_task('',
                                      total=len(dataset) // args.batch_size)

            for epoch in range(args.epoch):
                progress.update(steps, completed=0, refresh=True)

                for recon_loss, latent_loss, avg_mse, lr in self.train_epoch(
                        epoch, loader, optimizer, scheduler, device,
                        sample_path):
                    progress.update(
                        steps,
                        description=(
                            f'mse: {recon_loss:.5f}; '
                            f'latent: {latent_loss:.5f}; '
                            f'avg mse: {avg_mse:.5f}; lr: {lr:.5f}'),
                    )
                    progress.advance(steps)

                if dist_fn.is_primary():
                    torch.save(
                        self.state_dict(),
                        os.path.join(checkpoint_path,
                                     f'vqvae_{str(epoch + 1).zfill(3)}.pt'))

                progress.update(train,
                                description=f'epoch {epoch + 1}/{args.epoch}')
                progress.advance(train)
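Because this run() is a method on the model itself, a caller only needs an args object with the right attributes. A plausible invocation, pieced together from the attributes the method reads (the model class and all values are assumptions):

from types import SimpleNamespace

from torch import optim

args = SimpleNamespace(
    path='data/images',    # ImageFolder root
    normalize=True,
    distributed=False,
    batch_size=128,
    n_gpu=1,
    num_workers=2,
    optimizer=optim.Adam,  # run() expects the optimizer class, not an instance
    lr=3e-4,
    sched='cycle',
    epoch=10,
)

model = VQVAE()            # assumed: the model class that defines run()
model.run(args)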