示例#1
0
def main(config_path, **kwargs):
    """Train ``Model`` on ``DualViewDataset`` and checkpoint its backbone.

    Args:
        config_path: path of the experiment config, loaded with ``load_config``.
        **kwargs: overrides forwarded to ``load_config``.

    The full data loader is wrapped in a ``ChunkedDataLoader`` whose size is
    ``len(loader) / config.epochs`` (rounded), so that one training "epoch"
    here covers that many batches rather than a full pass.
    """
    config = load_config(config_path, **kwargs)

    full_loader = torch.utils.data.DataLoader(
        DualViewDataset(config.dataset_path, transform=build_transforms(config)),
        batch_size=config.train.batch_size,
        drop_last=True,
        shuffle=True,
        num_workers=config.workers,
        worker_init_fn=worker_init_fn,
    )
    chunk_size = round(len(full_loader) / config.epochs)
    train_data_loader = ChunkedDataLoader(full_loader, size=chunk_size)

    model = Model(backbone=config.model.backbone).to(DEVICE)
    optimizer = build_optimizer(model.parameters(), config)
    scheduler = build_scheduler(optimizer, config, len(train_data_loader))

    checkpoint_path = os.path.join(config.experiment_path, "checkpoint.pth")
    best_score = -float("inf")
    for epoch in range(1, config.epochs + 1):
        train_epoch(model, train_data_loader, optimizer, scheduler, epoch=epoch, config=config)

        # No evaluation is run: the epoch index stands in for the score, so
        # each epoch beats the previous best and the checkpoint is overwritten.
        score = epoch
        if score <= best_score:
            continue
        best_score = score
        torch.save({"model": model.backbone.state_dict(), "config": config}, checkpoint_path)
        print("new best score saved: {:.4f}".format(score))
示例#2
0
def main(config_path, **kwargs):
    """Train an FCOS detector, resuming from ``<experiment_path>/checkpoint.pth``.

    Args:
        config_path: path of the experiment config, loaded with ``load_config``.
        **kwargs: overrides forwarded to ``load_config``.

    Raises:
        AssertionError: if ``config.dataset`` is not ``"coco"``.

    ``config.restore_path`` (if set) provides initial model weights only; an
    existing checkpoint in the experiment directory is loaded afterwards and
    also supplies the epoch to resume from.
    """
    config = load_config(config_path, **kwargs)
    del kwargs
    random_seed(config.seed)

    box_coder = BoxCoder(config.model.levels)

    def image_pipeline(*extra_steps):
        # Image-only transforms, always finishing with tensor conversion and
        # MEAN/STD normalization.
        return ApplyTo(
            "image",
            T.Compose([*extra_steps, T.ToTensor(), T.Normalize(mean=MEAN, std=STD)]),
        )

    train_transform = T.Compose([
        Resize(config.resize_size),
        RandomCrop(config.crop_size),
        RandomFlipLeftRight(),
        image_pipeline(T.ColorJitter(brightness=0.3, contrast=0.3, saturation=0.3)),
        FilterBoxes(),
        BuildTargets(box_coder),
    ])
    # NOTE(review): evaluation also uses RandomCrop, so eval samples are not
    # deterministic — confirm this is intended.
    eval_transform = T.Compose([
        Resize(config.resize_size),
        RandomCrop(config.crop_size),
        image_pipeline(),
        FilterBoxes(),
        BuildTargets(box_coder),
    ])

    if config.dataset != "coco":
        raise AssertionError("invalid config.dataset {}".format(
            config.dataset))
    Dataset = CocoDataset

    train_dataset, eval_dataset = (
        Dataset(config.dataset_path, subset=subset, transform=transform)
        for subset, transform in (("train", train_transform), ("eval", eval_transform))
    )
    class_names = train_dataset.class_names

    shared_loader_args = dict(
        num_workers=config.workers,
        collate_fn=collate_fn,
        worker_init_fn=worker_init_fn,
    )
    train_data_loader = torch.utils.data.DataLoader(
        train_dataset,
        batch_size=config.train.batch_size,
        drop_last=True,
        shuffle=True,
        **shared_loader_args,
    )
    if config.train_steps is not None:
        train_data_loader = DataLoaderSlice(train_data_loader, config.train_steps)
    # Only needed by the evaluation pass, which is currently disabled below.
    eval_data_loader = torch.utils.data.DataLoader(
        eval_dataset,
        batch_size=config.eval.batch_size,
        drop_last=False,
        shuffle=True,
        **shared_loader_args,
    )

    model = FCOS(config.model, num_classes=Dataset.num_classes)
    if config.model.freeze_bn:
        model = BatchNormFreeze(model)
    model = model.to(DEVICE)

    optimizer = build_optimizer(model.parameters(), config)

    checkpoint_path = os.path.join(config.experiment_path, "checkpoint.pth")
    saver = Saver({"model": model, "optimizer": optimizer})
    if config.restore_path is not None:
        saver.load(config.restore_path, keys=["model"])
    start_epoch = saver.load(checkpoint_path) if os.path.exists(checkpoint_path) else 0

    scheduler = build_scheduler(optimizer, config, len(train_data_loader), start_epoch)

    for epoch in range(start_epoch, config.train.epochs):
        train_epoch(
            model=model,
            optimizer=optimizer,
            scheduler=scheduler,
            data_loader=train_data_loader,
            box_coder=box_coder,
            class_names=class_names,
            epoch=epoch,
            config=config,
        )
        gc.collect()
        # Evaluation is disabled; to re-enable:
        # eval_epoch(model=model, data_loader=eval_data_loader, box_coder=box_coder,
        #            class_names=class_names, epoch=epoch, config=config)
        gc.collect()

        # The stored epoch is the next one to run on resume.
        saver.save(checkpoint_path, epoch=epoch + 1)
示例#3
0
def main(config_path, **kwargs):
    """Semi-supervised CIFAR-10 training on a labeled (x) / unlabeled (u) split.

    Args:
        config_path: path of the experiment config, loaded with ``load_config``.
        **kwargs: overrides forwarded to ``load_config``.

    Every ``config.log_interval`` epochs the model is evaluated and a
    ``checkpoint_<epoch>.pth`` is written to ``config.experiment_path``.
    """
    config = load_config(config_path, **kwargs)

    transform_x, transform_u, eval_transform = build_transforms()

    def cifar10(train, transform=None):
        # Each split is backed by its own CIFAR-10 dataset object.
        return torchvision.datasets.CIFAR10(config.dataset_path,
                                            train=train,
                                            transform=transform,
                                            download=True)

    x_indices, u_indices = build_x_u_split(cifar10(train=True), config.train.num_labeled)

    x_dataset = torch.utils.data.Subset(cifar10(train=True, transform=transform_x), x_indices)
    # Two copies of the unlabeled subset (same indices, same transform_u).
    u_dataset = UDataset(*(
        torch.utils.data.Subset(cifar10(train=True, transform=transform_u), u_indices)
        for _ in range(2)
    ))
    eval_dataset = cifar10(train=False, transform=eval_transform)

    def shuffled_train_loader(dataset):
        return torch.utils.data.DataLoader(
            dataset,
            batch_size=config.train.batch_size,
            drop_last=True,
            shuffle=True,
            num_workers=config.workers,
        )

    train_data_loader = XUDataLoader(shuffled_train_loader(x_dataset),
                                     shuffled_train_loader(u_dataset))
    eval_data_loader = torch.utils.data.DataLoader(eval_dataset,
                                                   batch_size=config.eval.batch_size,
                                                   num_workers=config.workers)

    model = Model(config.model, NUM_CLASSES).to(DEVICE)
    model.apply(weights_init)
    optimizer = build_optimizer(model.parameters(), config)
    scheduler = build_scheduler(optimizer, config, len(train_data_loader))
    saver = Saver(dict(model=model, optimizer=optimizer, scheduler=scheduler))
    if config.restore_path is not None:
        # Only the model weights are restored; optimizer/scheduler start fresh.
        saver.load(config.restore_path, keys=["model"])

    for epoch in range(1, config.epochs + 1):
        train_epoch(model, train_data_loader, optimizer, scheduler, epoch=epoch, config=config)
        if epoch % config.log_interval == 0:
            eval_epoch(model, eval_data_loader, epoch=epoch, config=config)
            checkpoint_name = "checkpoint_{}.pth".format(epoch)
            saver.save(os.path.join(config.experiment_path, checkpoint_name), epoch=epoch)
示例#4
0
def main(config_path, **kwargs):
    """Train ``Model`` on the ``LJ`` dataset with character-vocabulary inputs.

    Args:
        config_path: path of the experiment config, loaded with ``load_config``.
        **kwargs: overrides forwarded to ``load_config``.

    Every epoch is followed by an evaluation pass and a
    ``checkpoint_<epoch>.pth`` in ``config.experiment_path``.
    """
    config = load_config(config_path, **kwargs)

    vocab = CharVocab()
    train_transform, eval_transform = build_transforms(vocab, config)

    train_dataset = LJ(config.dataset_path, subset="train", transform=train_transform)
    eval_dataset = LJ(config.dataset_path, subset="test", transform=eval_transform)

    def make_loader(dataset, batch_size, training):
        # BatchSampler receives per-sample sizes; training batches are shuffled
        # and an incomplete trailing batch is dropped, eval batches are not.
        sampler = BatchSampler(
            compute_sample_sizes(dataset),
            batch_size=batch_size,
            shuffle=training,
            drop_last=training,
        )
        return torch.utils.data.DataLoader(
            dataset,
            batch_sampler=sampler,
            num_workers=config.workers,
            collate_fn=collate_fn,
        )

    train_data_loader = make_loader(train_dataset, config.train.batch_size, training=True)
    eval_data_loader = make_loader(eval_dataset, config.eval.batch_size, training=False)

    # NOTE(review): this statistics file path is relative to the working directory.
    mean_std = torch.load("./tacotron/spectrogram_stats.pth")
    model = Model(
        config.model,
        vocab_size=len(vocab),
        sample_rate=config.sample_rate,
        mean_std=mean_std,
    )
    model = model.to(DEVICE)

    optimizer = build_optimizer(model.parameters(), config)
    scheduler = build_scheduler(optimizer, config, len(train_data_loader))
    saver = Saver({"model": model, "optimizer": optimizer, "scheduler": scheduler})
    if config.restore_path is not None:
        # Only the model weights are restored; optimizer/scheduler start fresh.
        saver.load(config.restore_path, keys=["model"])

    for epoch in range(1, config.train.epochs + 1):
        train_epoch(model, train_data_loader, optimizer, scheduler, epoch=epoch, config=config)
        eval_epoch(model, eval_data_loader, epoch=epoch, config=config)
        checkpoint_path = os.path.join(config.experiment_path, "checkpoint_{}.pth".format(epoch))
        saver.save(checkpoint_path, epoch=epoch)
示例#5
0
def main(config_path, **kwargs):
    """Train ``Model`` on Vimeo-90k, evaluating and saving weights every epoch.

    Args:
        config_path: path of the experiment config, loaded with ``load_config``.
        **kwargs: overrides forwarded to ``load_config``.

    Weights are written to ``<experiment_path>/model.pth`` after every epoch
    (overwriting the previous file). TensorBoard events go to the ``train``
    and ``eval`` subdirectories of ``config.experiment_path``; both writers
    are closed when training ends or fails.
    """
    config = load_config(config_path, **kwargs)
    # os.cpu_count() returns None when the CPU count cannot be determined, and
    # DataLoader rejects num_workers=None; fall back to in-process loading.
    num_workers = os.cpu_count() or 0

    eval_transform = ForEach(
        T.Compose([
            T.ToTensor(),
            T.Normalize(mean=[0.5], std=[0.25]),
        ]))
    # Training prepends random temporal/horizontal flips to the eval pipeline.
    train_transform = T.Compose([
        RandomTemporalFlip(),
        RandomHorizontalFlip(),
        eval_transform,
    ])
    train_data_loader = torch.utils.data.DataLoader(
        Vimeo90kDataset(config.dataset_path,
                        subset="train",
                        transform=train_transform),
        batch_size=config.batch_size,
        shuffle=True,
        num_workers=num_workers,
        drop_last=True,
    )
    eval_data_loader = torch.utils.data.DataLoader(
        Vimeo90kDataset(config.dataset_path,
                        subset="test",
                        transform=eval_transform),
        batch_size=config.batch_size,
        shuffle=False,
        num_workers=num_workers,
        drop_last=False,
    )
    # vis_data_loader = torch.utils.data.DataLoader(
    #     MiddleburyDataset(
    #         os.path.join(config.dataset_path, '..', 'middlebury-eval'), video='Dumptruck', transform=eval_transform),
    #     batch_size=1,
    #     shuffle=False,
    #     num_workers=os.cpu_count(),
    #     drop_last=False)

    model = Model()
    model.to(DEVICE)
    if config.restore_path is not None:
        # map_location allows restoring weights saved on a different device.
        model.load_state_dict(torch.load(config.restore_path, map_location=DEVICE))

    optimizer = torch.optim.Adam(model.parameters(),
                                 lr=1e-4,
                                 weight_decay=1e-4)
    # The LR is multiplied by 0.2 at each milestone in config.sched.steps;
    # scheduler.step() is called once per epoch below.
    scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer,
                                                     config.sched.steps,
                                                     gamma=0.2)

    train_writer = SummaryWriter(os.path.join(config.experiment_path, "train"))
    eval_writer = SummaryWriter(os.path.join(config.experiment_path, "eval"))

    try:
        for epoch in range(config.epochs):
            model.train()
            train_epoch(train_data_loader, model, optimizer, train_writer, epoch)

            model.eval()
            with torch.no_grad():
                eval_epoch(eval_data_loader, model, eval_writer, epoch)
                # vis_epoch(vis_data_loader, model, eval_writer, epoch)

            torch.save(model.state_dict(),
                       os.path.join(config.experiment_path, "model.pth"))
            scheduler.step()
    finally:
        # Flush and release the event files even if training raised.
        train_writer.close()
        eval_writer.close()
示例#6
0
def main(config_path, **kwargs):
    """Train an advantage actor-critic agent on a vectorized environment.

    Rollouts of ``config.horizon`` steps are collected from ``config.workers``
    environments. Each rollout yields one gradient step on a combined loss:
    the actor term is the policy gradient plus an entropy bonus, and the
    critic term is the squared error against n-step bootstrapped returns.
    Training stops after ``config.episodes`` finished episodes.

    Every ``config.log_interval`` episodes the scalar metrics (and, once a
    rollout has been trained on, rollout histograms) are written to
    TensorBoard, and a ``model_<episode>.pth`` is saved.

    Args:
        config_path: path of the experiment config, loaded with ``load_config``.
        **kwargs: overrides forwarded to ``load_config``.
    """
    config = load_config(config_path, **kwargs)
    del config_path, kwargs

    writer = SummaryWriter(config.experiment_path)

    seed_torch(config.seed)
    env = VecEnv([lambda: build_env(config) for _ in range(config.workers)])
    if config.render:
        env = wrappers.TensorboardBatchMonitor(env, writer,
                                               config.log_interval)
    env = wrappers.Torch(env, dtype=torch.float, device=DEVICE)
    env.seed(config.seed)

    model = Model(config.model, env.observation_space, env.action_space)
    model = model.to(DEVICE)
    if config.restore_path is not None:
        # map_location allows restoring weights saved on a different device.
        model.load_state_dict(torch.load(config.restore_path, map_location=DEVICE))
    optimizer = build_optimizer(config.opt, model.parameters())
    # The schedule advances once per finished episode (scheduler.step() below).
    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
        optimizer, config.episodes)

    metrics = {
        "loss": Mean(),
        "lr": Last(),
        "eps": FPS(),
        "ep/length": Mean(),
        "ep/return": Mean(),
        "rollout/entropy": Mean(),
    }

    # ==================================================================================================================
    # training loop
    model.train()
    episode = 0
    s = env.reset()
    # Tensors from the most recent optimization step, logged as histograms.
    # They stay None until the first rollout has been trained on. Without this
    # initialization, a log point reached during the first rollout raised
    # UnboundLocalError.
    rollout = returns = values = advantages = None

    bar = tqdm(total=config.episodes, desc="training")
    try:
        while episode < config.episodes:
            history = History()

            # Collect a rollout without gradients; its states are re-fed
            # through the model below to build the training graph.
            with torch.no_grad():
                for _ in range(config.horizon):
                    a, _ = model(s)
                    a = a.sample()
                    s_prime, r, d, info = env.step(a)
                    history.append(state=s,
                                   action=a,
                                   reward=r,
                                   done=d,
                                   state_prime=s_prime)
                    s = s_prime

                    # One iteration per environment whose episode ended this step.
                    (indices, ) = torch.where(d)
                    for i in indices:
                        metrics["eps"].update(1)
                        metrics["ep/length"].update(info[i]["episode"]["l"])
                        metrics["ep/return"].update(info[i]["episode"]["r"])
                        episode += 1
                        scheduler.step()
                        bar.update(1)

                        if episode % config.log_interval == 0 and episode > 0:
                            for k in metrics:
                                writer.add_scalar(k,
                                                  metrics[k].compute_and_reset(),
                                                  global_step=episode)
                            # These histograms describe the previous
                            # optimization step, so skip them until one exists.
                            if rollout is not None:
                                writer.add_histogram("rollout/action",
                                                     rollout.actions,
                                                     global_step=episode)
                                writer.add_histogram("rollout/reward",
                                                     rollout.rewards,
                                                     global_step=episode)
                                writer.add_histogram("rollout/return",
                                                     returns,
                                                     global_step=episode)
                                writer.add_histogram("rollout/value",
                                                     values,
                                                     global_step=episode)
                                writer.add_histogram("rollout/advantage",
                                                     advantages,
                                                     global_step=episode)

                            torch.save(
                                model.state_dict(),
                                os.path.join(config.experiment_path,
                                             "model_{}.pth".format(episode)),
                            )

            rollout = history.full_rollout()
            dist, values = model(rollout.states)
            with torch.no_grad():
                # Bootstrap from the value of the state after the last step.
                _, value_prime = model(rollout.states_prime[:, -1])
                returns = n_step_bootstrapped_return(rollout.rewards,
                                                     value_prime,
                                                     rollout.dones,
                                                     discount=config.gamma)

            # critic
            errors = returns - values
            critic_loss = errors**2

            # actor: the TD errors (detached) serve as advantages
            advantages = errors.detach()
            log_prob = dist.log_prob(rollout.actions)
            entropy = dist.entropy()

            # Continuous actions: sum per-dimension log-probs/entropies.
            if isinstance(env.action_space, gym.spaces.Box):
                log_prob = log_prob.sum(-1)
                entropy = entropy.sum(-1)
            assert log_prob.dim() == entropy.dim() == 2

            actor_loss = -log_prob * advantages - config.entropy_weight * entropy

            # loss
            loss = (actor_loss + 0.5 * critic_loss).mean(1)

            metrics["loss"].update(loss.data.cpu().numpy())
            # Read the LR actually in use. scheduler.get_lr() must not be called
            # outside scheduler.step(): recent PyTorch warns and, with
            # CosineAnnealingLR's recursive formula, returns a wrong value.
            metrics["lr"].update(
                np.squeeze([group["lr"] for group in optimizer.param_groups]))
            metrics["rollout/entropy"].update(dist.entropy().data.cpu().numpy())

            # training
            optimizer.zero_grad()
            loss.mean().backward()
            nn.utils.clip_grad_norm_(model.parameters(), 0.5)
            optimizer.step()
    finally:
        bar.close()
        env.close()
        # Flush pending TensorBoard events.
        writer.flush()
        writer.close()
示例#7
0
def main(config_path, **kwargs):
    """Train a generator/discriminator GAN pair, resuming from a checkpoint.

    Each batch performs these steps in order:
      * a generator step;
      * path-length regularization of the generator every
        ``config.gen.reg_interval`` batches;
      * an EMA update of ``gen_ema``;
      * a discriminator step on real and generated images;
      * R1 regularization of the discriminator every
        ``config.dsc.reg_interval`` batches.

    After every epoch, scalars, a discriminator PR curve and image grids are
    written to TensorBoard. A checkpoint is then saved to
    ``<experiment_path>/checkpoint.pth``.

    If that checkpoint exists, the models, the optimizers, ``pl_ema`` and
    ``z_fixed`` are restored, and training resumes after the last completed
    epoch. Checkpoints without an ``"epoch"`` entry restart at epoch 1.

    Args:
        config_path: path of the experiment config, loaded with ``load_config``.
        **kwargs: overrides forwarded to ``load_config``.
    """
    config = load_config(config_path, **kwargs)

    gen = Gen(
        image_size=config.image_size,
        base_channels=config.gen.base_channels,
        max_channels=config.gen.max_channels,
        z_channels=config.noise_size,
    ).to(DEVICE)
    dsc = Dsc(
        image_size=config.image_size,
        base_channels=config.dsc.base_channels,
        max_channels=config.dsc.max_channels,
        batch_std=config.dsc.batch_std,
    ).to(DEVICE)
    # Separate generator copy whose weights are tracked by ModuleEMA.
    gen_ema = copy.deepcopy(gen)
    ema = ModuleEMA(gen_ema, config.gen.ema)
    # Moving average of the mean path length, used as the target of the
    # path-length penalty; persisted in the checkpoint.
    pl_ema = torch.zeros([], device=DEVICE)

    opt_gen = build_optimizer(gen.parameters(), config)
    opt_dsc = build_optimizer(dsc.parameters(), config)

    z_dist = ZDist(config.noise_size, DEVICE)
    # 64 fixed latents reused every epoch so visualizations are comparable.
    z_fixed = z_dist(8 ** 2, truncation=1)

    checkpoint_path = os.path.join(config.experiment_path, "checkpoint.pth")
    start_epoch = 1
    if os.path.exists(checkpoint_path):
        # map_location lets a checkpoint written on another device be restored.
        state = torch.load(checkpoint_path, map_location=DEVICE)
        dsc.load_state_dict(state["dsc"])
        gen.load_state_dict(state["gen"])
        gen_ema.load_state_dict(state["gen_ema"])
        opt_gen.load_state_dict(state["opt_gen"])
        opt_dsc.load_state_dict(state["opt_dsc"])
        pl_ema.copy_(state["pl_ema"])
        z_fixed.copy_(state["z_fixed"])
        # Resume after the last completed epoch instead of re-running from
        # epoch 1, which would also re-use TensorBoard global steps.
        start_epoch = state.get("epoch", 0) + 1
        print("restored from checkpoint")

    dataset = build_dataset(config)
    print("dataset size: {}".format(len(dataset)))
    data_loader = torch.utils.data.DataLoader(
        dataset,
        batch_size=config.batch_size,
        shuffle=True,
        num_workers=config.num_workers,
        pin_memory=True,
        drop_last=True,
    )
    data_loader = ChunkedDataLoader(data_loader, config.batches_in_epoch)

    dsc_compute_loss, gen_compute_loss = build_loss(config)

    writer = SummaryWriter(config.experiment_path)
    for epoch in range(start_epoch, config.num_epochs + 1):
        metrics = {
            "dsc/loss": Mean(),
            "gen/loss": Mean(),
        }
        # Discriminator outputs and their real(1)/fake(0) labels, used for the
        # epoch-level average precision and PR curve.
        dsc_logits = Concat()
        dsc_targets = Concat()

        gen.train()
        dsc.train()
        gen_ema.train()
        for batch_i, real in enumerate(
            tqdm(
                data_loader,
                desc="{}/{}".format(epoch, config.num_epochs),
                disable=config.debug,
                smoothing=0.1,
            ),
            1,
        ):
            real = real.to(DEVICE)

            # generator: train
            with zero_grad_and_step(opt_gen):
                if config.debug:
                    print("gen")
                fake, _ = gen(z_dist(config.batch_size), z_dist(config.batch_size))
                assert (
                    fake.size() == real.size()
                ), "fake size {} does not match real size {}".format(fake.size(), real.size())

                # gen fake
                logits = dsc(fake)
                loss = gen_compute_loss(logits, True)
                loss.mean().backward()
                metrics["gen/loss"].update(loss.detach())

            # generator: regularize
            if batch_i % config.gen.reg_interval == 0:
                with zero_grad_and_step(opt_gen):
                    # Path-length regularization: penalize deviation of
                    # |J^T y| (y random image-space noise) from its running mean.
                    fake, w = gen(z_dist(config.batch_size), z_dist(config.batch_size))
                    validate_shape(w, (None, config.batch_size, config.noise_size))
                    pl_noise = torch.randn_like(fake) / math.sqrt(fake.size(2) * fake.size(3))
                    (pl_grads,) = torch.autograd.grad(
                        outputs=[(fake * pl_noise).sum()],
                        inputs=[w],
                        create_graph=True,
                        only_inputs=True,
                    )
                    pl_lengths = pl_grads.square().sum(2).mean(0).sqrt()
                    pl_mean = pl_ema.lerp(pl_lengths.mean(), config.gen.pl_decay)
                    pl_ema.copy_(pl_mean.detach())
                    pl_penalty = (pl_lengths - pl_mean).square()
                    # Scaled by reg_interval because it runs only every N batches.
                    loss_pl = pl_penalty * config.gen.pl_weight * config.gen.reg_interval
                    loss_pl.mean().backward()

            # generator: update moving average
            ema.update(gen)

            # discriminator: train
            with zero_grad_and_step(opt_dsc):
                if config.debug:
                    print("dsc")
                with torch.no_grad():
                    fake, _ = gen(z_dist(config.batch_size), z_dist(config.batch_size))
                    assert (
                        fake.size() == real.size()
                    ), "fake size {} does not match real size {}".format(fake.size(), real.size())

                # dsc real
                logits = dsc(real)
                loss = dsc_compute_loss(logits, True)
                loss.mean().backward()
                metrics["dsc/loss"].update(loss.detach())

                dsc_logits.update(logits.detach())
                dsc_targets.update(torch.ones_like(logits))

                # dsc fake
                logits = dsc(fake.detach())
                loss = dsc_compute_loss(logits, False)
                loss.mean().backward()
                metrics["dsc/loss"].update(loss.detach())

                dsc_logits.update(logits.detach())
                dsc_targets.update(torch.zeros_like(logits))

            # discriminator: regularize
            if batch_i % config.dsc.reg_interval == 0:
                with zero_grad_and_step(opt_dsc):
                    # R1 regularization: gradient penalty on real images.
                    real = real.detach().requires_grad_(True)
                    logits = dsc(real)
                    (r1_grads,) = torch.autograd.grad(
                        outputs=[logits.sum()],
                        inputs=[real],
                        create_graph=True,
                        only_inputs=True,
                    )
                    r1_penalty = r1_grads.square().sum([1, 2, 3])
                    # Scaled by reg_interval because it runs only every N batches.
                    loss_r1 = r1_penalty * (config.dsc.r1_gamma * 0.5) * config.dsc.reg_interval
                    loss_r1.mean().backward()

        dsc.eval()
        gen.eval()
        gen_ema.eval()
        with torch.no_grad(), log_duration("visualization took {:.2f} seconds"):
            infer = Infer(gen)
            infer_ema = Infer(gen_ema)

            # `real` is the last batch of the epoch.
            real = infer.postprocess(real)
            fake = infer(z_fixed)
            fake_ema = infer_ema(z_fixed)
            fake_ema_mix, fake_ema_mix_nrow = visualize_style_mixing(
                infer_ema, z_fixed[0 : 8 * 2 : 2], z_fixed[1 : 8 * 2 : 2]
            )

            fake_ema_noise, fake_ema_noise_nrow = stack_images(
                [
                    fake_ema[:8],
                    visualize_noise(infer_ema, z_fixed[:8], 128),
                ]
            )

            dsc_logits = dsc_logits.compute_and_reset().data.cpu().numpy()
            dsc_targets = dsc_targets.compute_and_reset().data.cpu().numpy()

            metrics = {k: metrics[k].compute_and_reset() for k in metrics}
            metrics["dsc/ap"] = precision_recall_auc(input=dsc_logits, target=dsc_targets)
            for k in metrics:
                writer.add_scalar(k, metrics[k], global_step=epoch)
            writer.add_figure(
                "dsc/pr_curve",
                plot_pr_curve(input=dsc_logits, target=dsc_targets),
                global_step=epoch,
            )
            writer.add_image(
                "real",
                torchvision.utils.make_grid(real, nrow=compute_nrow(real)),
                global_step=epoch,
            )
            writer.add_image(
                "fake",
                torchvision.utils.make_grid(fake, nrow=compute_nrow(fake)),
                global_step=epoch,
            )
            writer.add_image(
                "fake_ema",
                torchvision.utils.make_grid(fake_ema, nrow=compute_nrow(fake_ema)),
                global_step=epoch,
            )
            writer.add_image(
                "fake_ema_mix",
                torchvision.utils.make_grid(fake_ema_mix, nrow=fake_ema_mix_nrow),
                global_step=epoch,
            )
            writer.add_image(
                "fake_ema_noise",
                torchvision.utils.make_grid(fake_ema_noise, nrow=fake_ema_noise_nrow * 2),
                global_step=epoch,
            )

        # Saved outside the timed block so the logged visualization time does
        # not include checkpoint serialization. "epoch" enables resuming.
        torch.save(
            {
                "gen": gen.state_dict(),
                "gen_ema": gen_ema.state_dict(),
                "dsc": dsc.state_dict(),
                "opt_gen": opt_gen.state_dict(),
                "opt_dsc": opt_dsc.state_dict(),
                "pl_ema": pl_ema,
                "z_fixed": z_fixed,
                "epoch": epoch,
            },
            checkpoint_path,
        )

    writer.flush()
    writer.close()
示例#8
0
def main(config_path, **kwargs):
    """Train one K-fold split and keep the single best-scoring checkpoint.

    The experiment directory is suffixed with ``F<fold>``. Training uses the
    2020 training part of ``config.fold`` concatenated with the 2019 data.
    Evaluation uses the 2020 held-out part of the same fold.

    Each epoch is evaluated twice: once after ``optimizer.train()``, and once
    after ``optimizer.eval()`` (logged with ``suffix="ema"``). Both scores
    compete for ``<experiment_path>/checkpoint_best.pth``.

    Args:
        config_path: path of the experiment config, loaded with ``load_config``.
        **kwargs: overrides forwarded to ``load_config``.
    """
    config = load_config(config_path, **kwargs)
    config.experiment_path = os.path.join(config.experiment_path, "F{}".format(config.fold))
    del kwargs
    random_seed(config.seed)

    train_transform, eval_transform = build_transforms(config)

    train_dataset = ConcatDataset(
        [
            Dataset2020KFold(
                os.path.join(config.dataset_path, "2020"),
                train=True,
                fold=config.fold,
                transform=train_transform,
            ),
            Dataset2019(os.path.join(config.dataset_path, "2019"), transform=train_transform),
        ]
    )
    eval_dataset = Dataset2020KFold(
        os.path.join(config.dataset_path, "2020"),
        train=False,
        fold=config.fold,
        transform=eval_transform,
    )

    train_data_loader = torch.utils.data.DataLoader(
        train_dataset,
        batch_size=config.train.batch_size,
        shuffle=True,
        drop_last=True,
        num_workers=config.workers,
    )
    eval_data_loader = torch.utils.data.DataLoader(
        eval_dataset,
        batch_size=config.eval.batch_size,
        shuffle=False,
        drop_last=False,
        num_workers=config.workers,
    )

    model = Model(config.model).to(DEVICE)
    optimizer = build_optimizer(model.parameters(), config)
    scheduler = build_scheduler(optimizer, config, len(train_data_loader))
    saver = Saver(
        {
            "model": model,
            "optimizer": optimizer,
            "scheduler": scheduler,
        }
    )
    if config.restore_path is not None:
        saver.load(config.restore_path, keys=["model"])

    best_path = os.path.join(config.experiment_path, "checkpoint_best.pth")
    # Start below any attainable score so the first evaluation always writes a
    # checkpoint, whatever the metric's range (0.0 skipped non-positive scores).
    best_score = -float("inf")

    def save_if_best(score, epoch):
        # Saves the model's current weights when `score` beats the best so far.
        nonlocal best_score
        if score > best_score:
            best_score = score
            saver.save(best_path, epoch=epoch)

    for epoch in range(1, config.train.epochs + 1):
        optimizer.train()
        train_epoch(model, train_data_loader, optimizer, scheduler, epoch=epoch, config=config)
        save_if_best(eval_epoch(model, eval_data_loader, epoch=epoch, config=config), epoch)

        # NOTE(review): given the "ema" suffix, optimizer.eval() presumably swaps
        # averaged weights into the model — confirm against the optimizer class.
        optimizer.eval()
        save_if_best(
            eval_epoch(model, eval_data_loader, epoch=epoch, config=config, suffix="ema"),
            epoch,
        )
示例#9
0
def main(config_path, **kwargs):
    """Teacher/student semi-supervised training on a CIFAR-10 x/u split.

    Args:
        config_path: path of the experiment config, loaded with ``load_config``.
        **kwargs: overrides forwarded to ``load_config``.

    ``model`` is an ``nn.ModuleDict`` holding a teacher and a student network.
    Each has its own optimizer (built from ``config.train.teacher`` /
    ``config.train.student``) and scheduler. Every ``config.log_interval``
    epochs the model is evaluated and ``checkpoint_<epoch>.pth`` is written.
    The checkpoint holds only the model weights.
    """
    config = load_config(config_path, **kwargs)

    x_transform, u_transform, eval_transform = build_transforms()

    def cifar10_train(transform=None):
        # A fresh CIFAR-10 training-set object for each use.
        return torchvision.datasets.CIFAR10(config.dataset_path,
                                            train=True,
                                            transform=transform,
                                            download=True)

    x_indices, u_indices = build_x_u_split(cifar10_train(), config.train.num_labeled)

    x_dataset = torch.utils.data.Subset(cifar10_train(x_transform), x_indices)
    u_dataset = torch.utils.data.Subset(cifar10_train(u_transform), u_indices)
    eval_dataset = torchvision.datasets.CIFAR10(config.dataset_path,
                                                train=False,
                                                transform=eval_transform,
                                                download=True)

    def shuffled_loader(dataset, batch_size):
        return torch.utils.data.DataLoader(
            dataset,
            batch_size=batch_size,
            drop_last=True,
            shuffle=True,
            num_workers=config.workers,
        )

    train_data_loader = XUDataLoader(
        shuffled_loader(x_dataset, config.train.x_batch_size),
        shuffled_loader(u_dataset, config.train.u_batch_size),
    )
    eval_data_loader = torch.utils.data.DataLoader(eval_dataset,
                                                   batch_size=config.eval.batch_size,
                                                   num_workers=config.workers)

    teacher_cfg = config.train.teacher
    student_cfg = config.train.student
    model = nn.ModuleDict({
        "teacher": Model(NUM_CLASSES, teacher_cfg.dropout),
        "student": Model(NUM_CLASSES, student_cfg.dropout),
    }).to(DEVICE)
    model.apply(weights_init)

    opt_teacher = build_optimizer(model.teacher.parameters(), teacher_cfg)
    opt_student = build_optimizer(model.student.parameters(), student_cfg)

    sched_teacher = build_scheduler(opt_teacher, config, len(train_data_loader))
    sched_student = build_scheduler(opt_student, config, len(train_data_loader))

    saver = Saver({"model": model})
    if config.restore_path is not None:
        saver.load(config.restore_path, keys=["model"])

    for epoch in range(1, config.epochs + 1):
        train_epoch(
            model,
            train_data_loader,
            opt_teacher=opt_teacher,
            opt_student=opt_student,
            sched_teacher=sched_teacher,
            sched_student=sched_student,
            epoch=epoch,
            config=config,
        )
        if epoch % config.log_interval == 0:
            eval_epoch(model, eval_data_loader, epoch=epoch, config=config)
            checkpoint_name = "checkpoint_{}.pth".format(epoch)
            saver.save(os.path.join(config.experiment_path, checkpoint_name), epoch=epoch)
示例#10
0
def main(config_path, **kwargs):
    """Train a recurrent policy with one update per collected episode.

    Each iteration rolls out one full episode with the model in eval mode and
    gradients disabled, builds a rollout from the recorded transitions, computes
    the loss with ``compute_loss`` and takes one optimizer step plus one
    LR-scheduler step. Every ``config.log_interval`` episodes (episode 0
    excluded) the metric accumulators are written to TensorBoard and the model
    state dict is saved as ``model_{episode}.pth`` under
    ``config.experiment_path``.

    Args:
        config_path: Path of the config file passed to ``load_config``.
        **kwargs: Overrides forwarded to ``load_config``.
    """
    config = load_config(config_path, **kwargs)
    del config_path, kwargs

    writer = SummaryWriter(config.experiment_path)

    seed_torch(config.seed)
    env = wrappers.Batch(build_env(config))
    if config.render:
        env = wrappers.TensorboardBatchMonitor(env, writer, config.log_interval)
    env = wrappers.torch.Torch(env, device=DEVICE)
    env.seed(config.seed)

    model = Model(config.model, env.observation_space, env.action_space)
    model = model.to(DEVICE)
    if config.restore_path is not None:
        # map_location allows restoring a checkpoint saved on a different device (e.g. GPU -> CPU).
        model.load_state_dict(torch.load(config.restore_path, map_location=DEVICE))
    optimizer = build_optimizer(config.opt, model.parameters())
    # The scheduler is stepped once per episode, so the cosine period covers the whole run.
    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, config.episodes)

    metrics = {
        "loss": Mean(),
        "lr": Last(),
        "eps": FPS(),
        "ep/length": Mean(),
        "ep/return": Mean(),
        "rollout/reward": Mean(),
        "rollout/advantage": Mean(),
        "rollout/entropy": Mean(),
    }

    try:
        # training loop ================================================================================================
        for episode in tqdm(range(config.episodes), desc="training"):
            hist = History()
            s = env.reset()
            h = model.zero_state(1)
            # Start with done=True; presumably tells the model to treat the first step as an
            # episode start (e.g. reset its hidden state) -- TODO confirm against Model.forward.
            d = torch.ones(1, dtype=torch.bool)

            # Data collection: no gradients, the loss is recomputed from the stored rollout below.
            model.eval()
            with torch.no_grad():
                while True:
                    trans = hist.append_transition()

                    trans.record(state=s, hidden=h, done=d)
                    a, _, h = model(s, h, d)
                    a = a.sample()
                    s, r, d, info = env.step(a)
                    trans.record(action=a, reward=r)

                    if d:
                        break

            # optimization =============================================================================================
            model.train()

            # build rollout
            rollout = hist.full_rollout()

            # loss
            loss = compute_loss(env, model, rollout, metrics, config)

            # metrics
            metrics["loss"].update(loss.data.cpu().numpy())
            metrics["lr"].update(np.squeeze(scheduler.get_last_lr()))

            # training
            optimizer.zero_grad()
            loss.mean().backward()
            if config.grad_clip_norm is not None:
                nn.utils.clip_grad_norm_(model.parameters(), config.grad_clip_norm)
            optimizer.step()
            scheduler.step()

            metrics["eps"].update(1)
            # Assumes an env wrapper stores per-episode stats in info[0]["episode"] -- TODO confirm.
            metrics["ep/length"].update(info[0]["episode"]["l"])
            metrics["ep/return"].update(info[0]["episode"]["r"])

            if episode % config.log_interval == 0 and episode > 0:
                for k in metrics:
                    writer.add_scalar(k, metrics[k].compute_and_reset(), global_step=episode)
                torch.save(
                    model.state_dict(),
                    os.path.join(config.experiment_path, "model_{}.pth".format(episode)),
                )
    finally:
        # Make sure buffered TensorBoard events reach disk even if training is interrupted.
        writer.flush()
        writer.close()
# 示例#11 (Example #11)
# 0
def main(config_path, **kwargs):
    """Train a ``Gen``/``Dsc`` image GAN on an image-folder dataset.

    Every batch, the discriminator is updated on real and generated images,
    with a lazily applied R1 gradient penalty on real images. Optional weight
    clipping is then applied. Every ``config.dsc.num_steps`` batches the
    generator is updated once.

    At the end of each epoch the following are written to TensorBoard under
    ``./log``:
    - the averaged losses;
    - the discriminator's precision/recall AUC and PR curve;
    - grids of up to 16 real and 16 fake images.

    Args:
        config_path: Path of the config file passed to ``load_config``.
        **kwargs: Overrides forwarded to ``load_config``.
    """
    config = load_config(config_path, **kwargs)

    gen = Gen(
        image_size=config.image_size,
        image_channels=3,
        base_channels=config.gen.base_channels,
        z_channels=config.noise_size,
    ).to(DEVICE)
    dsc = Dsc(
        image_size=config.image_size,
        image_channels=3,
        base_channels=config.dsc.base_channels,
    ).to(DEVICE)
    # gen_ema = gen
    # gen_ema = copy.deepcopy(gen)
    # ema = EMA(gen_ema, 0.99)

    gen.train()
    dsc.train()
    # gen_ema.train()

    opt_gen = build_optimizer(gen.parameters(), config)
    opt_dsc = build_optimizer(dsc.parameters(), config)

    # Images are mapped to [-1, 1] (ToTensor gives [0, 1], then (x - 0.5) / 0.5).
    transform = T.Compose([
        T.Resize(config.image_size),
        T.RandomCrop(config.image_size),
        T.ToTensor(),
        T.Normalize([0.5], [0.5]),
    ])

    # dataset = torchvision.datasets.MNIST(
    #     "./data/mnist", train=True, transform=transform, download=True
    # )
    # dataset = torchvision.datasets.CelebA(
    #     "./data/celeba", split="all", transform=transform, download=True
    # )
    dataset = ImageFolderDataset("./data/wikiart/resized/landscape",
                                 transform=transform)
    data_loader = torch.utils.data.DataLoader(
        dataset,
        batch_size=config.batch_size,
        shuffle=True,
        num_workers=config.num_workers,
        pin_memory=True,
        drop_last=True,
    )

    noise_dist = torch.distributions.Normal(0, 1)
    dsc_compute_loss, gen_compute_loss = build_loss(config)

    # Lazy R1 regularization: the penalty is evaluated only once every
    # `r1_interval` discriminator steps and scaled by `r1_interval`, so its
    # average strength matches applying it on every step.
    r1_gamma = 10
    r1_interval = 8

    writer = SummaryWriter("./log")
    for epoch in range(1, config.num_epochs + 1):
        metrics = {
            "dsc/loss": Mean(),
            "gen/loss": Mean(),
        }
        dsc_logits = Concat()
        dsc_targets = Concat()

        for batch_index, real in enumerate(
                tqdm(data_loader,
                     desc="{}/{}".format(epoch, config.num_epochs),
                     disable=config.debug)):
            real = real.to(DEVICE)

            # train discriminator
            with zero_grad_and_step(opt_dsc):
                if config.debug:
                    print("dsc")
                noise = noise_dist.sample(
                    (real.size(0), config.noise_size)).to(DEVICE)
                with torch.no_grad():
                    fake = gen(noise)
                    assert (fake.size() == real.size(
                    )), "fake size {} does not match real size {}".format(
                        fake.size(), real.size())

                # dsc real
                logits = dsc(real)
                loss = dsc_compute_loss(logits, True)
                loss.mean().backward()
                metrics["dsc/loss"].update(loss.detach())

                dsc_logits.update(logits.detach())
                dsc_targets.update(torch.ones_like(logits))

                # dsc fake
                logits = dsc(fake.detach())
                loss = dsc_compute_loss(logits, False)
                loss.mean().backward()
                metrics["dsc/loss"].update(loss.detach())

                dsc_logits.update(logits.detach())
                dsc_targets.update(torch.zeros_like(logits))

                # Apply R1 only on every r1_interval-th step. The old `!= 0`
                # condition applied it on 7 of every 8 steps while still scaling
                # it by the interval, making it ~7x too strong.
                if (batch_index + 1) % r1_interval == 0:
                    # R1 penalty: squared gradient norm of D's output w.r.t. real images.
                    real = real.detach().requires_grad_(True)
                    logits = dsc(real)
                    (r1_grads, ) = torch.autograd.grad(
                        outputs=[logits.sum()],
                        inputs=[real],
                        # keep the graph so the penalty can be backpropagated into dsc
                        create_graph=True,
                    )
                    r1_penalty = r1_grads.square().sum([1, 2, 3])
                    loss_r1 = r1_penalty * (r1_gamma / 2) * r1_interval
                    loss_r1.mean().backward()

            if config.dsc.weight_clip is not None:
                clip_parameters(dsc, config.dsc.weight_clip)

            # The generator is updated once every `config.dsc.num_steps` discriminator steps.
            if (batch_index + 1) % config.dsc.num_steps != 0:
                continue

            # train generator
            with zero_grad_and_step(opt_gen):
                if config.debug:
                    print("gen")
                noise = noise_dist.sample(
                    (real.size(0), config.noise_size)).to(DEVICE)
                fake = gen(noise)
                assert (fake.size() == real.size()
                        ), "fake size {} does not match real size {}".format(
                            fake.size(), real.size())

                # gen fake
                logits = dsc(fake)
                loss = gen_compute_loss(logits, True)
                loss.mean().backward()
                metrics["gen/loss"].update(loss.detach())

                # update moving average
                # ema.update(gen)

        with torch.no_grad():
            # Map the last batch back from [-1, 1] to [0, 1] and keep up to 4x4 images.
            real, fake = [(x[:4**2] * 0.5 + 0.5).clamp(0, 1)
                          for x in [real, fake]]

            dsc_logits = dsc_logits.compute_and_reset().data.cpu().numpy()
            dsc_targets = dsc_targets.compute_and_reset().data.cpu().numpy()

            metrics = {k: metrics[k].compute_and_reset() for k in metrics}
            metrics["ap"] = precision_recall_auc(input=dsc_logits,
                                                 target=dsc_targets)
            for k in metrics:
                writer.add_scalar(k, metrics[k], global_step=epoch)
            writer.add_figure(
                "dsc/pr_curve",
                plot_pr_curve(input=dsc_logits, target=dsc_targets),
                global_step=epoch,
            )
            writer.add_image(
                "real",
                torchvision.utils.make_grid(real, nrow=compute_nrow(real)),
                global_step=epoch,
            )
            writer.add_image(
                "fake",
                torchvision.utils.make_grid(fake, nrow=compute_nrow(fake)),
                global_step=epoch,
            )
            # writer.add_image(
            #     "fake_ema",
            #     torchvision.utils.make_grid(fake, nrow=compute_nrow(fake)),
            #     global_step=epoch,
            # )

    writer.flush()
    writer.close()
# 示例#12 (Example #12)
# 0
def main(config_path, **kwargs):
    """Train a progressively growing GAN with the non-saturating softplus loss.

    The resolution level is chosen once per epoch via ``compute_level``, and
    the data transform is resized to ``4 * 2**level``. The blending factor
    ``a`` is recomputed per batch.

    Each batch takes one discriminator step and then one generator step.
    Per-epoch metrics and image grids are written to TensorBoard, and the
    model state dict is saved as ``model_{epoch}.pth`` under
    ``config.experiment_path``.

    Args:
        config_path: Path of the config file passed to ``load_config``.
        **kwargs: Overrides forwarded to ``load_config``.

    Raises:
        ValueError: If ``config.dataset`` is neither ``"mnist"`` nor ``"celeba"``.
    """
    config = load_config(config_path, **kwargs)

    transform, update_transform = build_transform()

    if config.dataset == "mnist":
        dataset = torchvision.datasets.MNIST(config.dataset_path,
                                             transform=transform,
                                             download=True)
    elif config.dataset == "celeba":
        dataset = torchvision.datasets.ImageFolder(config.dataset_path,
                                                   transform=transform)
    else:
        # Previously fell through and failed later with an UnboundLocalError on `dataset`.
        raise ValueError("unsupported dataset {!r}".format(config.dataset))

    data_loader = torch.utils.data.DataLoader(
        dataset,
        batch_size=config.batch_size,
        shuffle=True,
        num_workers=config.workers,
        drop_last=True,  # keeps every batch at config.batch_size, matching the noise batches below
    )

    model = nn.ModuleDict({
        "discriminator":
        Discriminator(config.image_size),
        "generator":
        Generator(config.image_size, config.latent_size),
    })
    model.to(DEVICE)
    if config.restore_path is not None:
        # map_location allows restoring a checkpoint saved on a different device.
        model.load_state_dict(torch.load(config.restore_path, map_location=DEVICE))

    discriminator_opt = torch.optim.Adam(model.discriminator.parameters(),
                                         lr=config.opt.lr,
                                         betas=config.opt.beta,
                                         eps=1e-8)
    generator_opt = torch.optim.Adam(model.generator.parameters(),
                                     lr=config.opt.lr,
                                     betas=config.opt.beta,
                                     eps=1e-8)

    noise_dist = torch.distributions.Normal(0, 1)

    writer = SummaryWriter(config.experiment_path)
    metrics = {
        "loss/discriminator": Mean(),
        "loss/generator": Mean(),
        "level": Last(),
        "alpha": Last(),
    }

    for epoch in range(1, config.epochs + 1):
        model.train()

        # The level is fixed for the whole epoch so the data transform can be resized once.
        level, _ = compute_level(epoch - 1, config.epochs, 0, len(data_loader),
                                 config.image_size, config.grow_min_level)
        update_transform(int(4 * 2**level))

        for i, (real, _) in enumerate(
                tqdm(data_loader, desc="epoch {} training".format(epoch))):
            # Only the per-batch blending factor `a` is taken; the epoch's `level` is kept.
            _, a = compute_level(
                epoch - 1,
                config.epochs,
                i,
                len(data_loader),
                config.image_size,
                config.grow_min_level,
            )

            real = real.to(DEVICE)

            # discriminator ############################################################################################
            discriminator_opt.zero_grad()

            # real
            scores = model.discriminator(real, level=level, a=a)
            loss = F.softplus(-scores)
            loss.mean().backward()
            loss_real = loss

            # fake (detached: the discriminator step must not backprop through the generator)
            noise = noise_dist.sample(
                (config.batch_size, config.latent_size)).to(DEVICE)
            fake = model.generator(noise, level=level, a=a)
            assert real.size() == fake.size()
            scores = model.discriminator(fake.detach(), level=level, a=a)
            loss = F.softplus(scores)
            loss.mean().backward()
            loss_fake = loss

            discriminator_opt.step()
            metrics["loss/discriminator"].update(
                (loss_real + loss_fake).data.cpu().numpy())

            # generator ################################################################################################
            generator_opt.zero_grad()

            # fake
            noise = noise_dist.sample(
                (config.batch_size, config.latent_size)).to(DEVICE)
            fake = model.generator(noise, level=level, a=a)
            assert real.size() == fake.size()
            scores = model.discriminator(fake, level=level, a=a)
            loss = F.softplus(-scores)
            loss.mean().backward()

            generator_opt.step()
            metrics["loss/generator"].update(loss.data.cpu().numpy())

            metrics["level"].update(level)
            metrics["alpha"].update(a)

        for k in metrics:
            writer.add_scalar(k,
                              metrics[k].compute_and_reset(),
                              global_step=epoch)
        # (x + 1) / 2 assumes images are in [-1, 1] -- TODO confirm against build_transform.
        writer.add_image("real",
                         utils.make_grid((real + 1) / 2),
                         global_step=epoch)
        writer.add_image("fake",
                         utils.make_grid((fake.detach() + 1) / 2),
                         global_step=epoch)

        torch.save(
            model.state_dict(),
            os.path.join(config.experiment_path, "model_{}.pth".format(epoch)))