Example #1
    Path(args.check_point).parent.mkdir(parents=True, exist_ok=True)
    Path(args.logs_root).mkdir(parents=True, exist_ok=True)

    optimizer = torch.optim.SGD(model.parameters(),
                                lr=args.lr,
                                momentum=args.momentum,
                                weight_decay=args.wd)
    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
        optimizer, args.epochs)

    start_epoch = 0
    if args.continue_train:
        start_epoch = load_weights(state_dict_path=args.check_point,
                                   models=model,
                                   model_names='model',
                                   optimizers=optimizer,
                                   optimizer_names='optimizer',
                                   return_val='start_epoch')

    pbar = tqdm(range(start_epoch, args.epochs))
    for epoch in pbar:
        model.train()
        train_losses = []
        for x1, x2 in train_loader:
            x1, x2 = x1.to(device), x2.to(device)
            z1, z2, p1, p2 = model(x1, x2)
            if args.symmetric:
                loss = (model.module.cosine_loss(p1, z2) +
                        model.module.cosine_loss(p2, z1)) / 2
            else:
                loss = model.module.cosine_loss(p1, z2)
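            # assumed continuation: the excerpt cuts off after the loss, so
            # the usual optimizer step / logging below is a sketch, not the
            # original code
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            train_losses.append(loss.item())
        scheduler.step()
        pbar.set_postfix({'Train Loss': sum(train_losses) / len(train_losses)})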
Example #2
    lr_scheduler_D = torch.optim.lr_scheduler.LambdaLR(
        optimizer_D,
        lr_lambda=LambdaLR(args.n_epochs, args.starting_epoch, args.decay_epoch).step
    )

    criterion_GAN = torch.nn.MSELoss()
    criterion_cycle = torch.nn.L1Loss()
    criterion_identity = torch.nn.L1Loss()

    pool_A = ReplayBuffer()
    pool_B = ReplayBuffer()

    if args.continue_train:
        # the resumed epoch must land in `starting_epoch`, which the LR
        # schedule above and the progress bar below both read
        args.starting_epoch = load_weights(
            state_dict_path=f"{args.checkpoint_dir}/{args.data_root.split('/')[-1]}.pth",
            models=[D_A, D_B, G_AB, G_BA],
            model_names=['D_A', 'D_B', 'G_AB', 'G_BA'],
            optimizers=[optimizer_G, optimizer_D],
            optimizer_names=['optimizer_G', 'optimizer_D'],
            return_val='start_epoch')
    pbar = tqdm(range(args.starting_epoch, args.n_epochs),
                total=(args.n_epochs - args.starting_epoch))
    dataloader = torch.utils.data.DataLoader(dataset, batch_size=args.batch_size, shuffle=True)
    sampled_idx = get_random_ids(len(dataloader), args.sample_batches)
    for epoch in pbar:
        G_AB.train()
        G_BA.train()
        D_A.train()
        D_B.train()
        disc_A_losses, gen_A_losses, disc_B_losses, gen_B_losses = [], [], [], []
        # the excerpt breaks off mid-statement here; the remaining accumulator
        # names are assumed to follow the adversarial/cycle/identity split
        (gen_ad_A_losses, gen_ad_B_losses,
         gen_cycle_losses, gen_identity_losses) = [], [], [], []
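        # a minimal sketch of one CycleGAN iteration assembled from the
        # objects defined above; this is an assumed continuation, not the
        # original code: the (real_A, real_B) batch layout, the ReplayBuffer
        # push_and_pop method, and the 10.0/5.0 loss weights are all guesses
        for i, (real_A, real_B) in enumerate(dataloader):
            # generator update: adversarial + cycle + identity terms
            optimizer_G.zero_grad()
            fake_B, fake_A = G_AB(real_A), G_BA(real_B)
            pred_fake_B, pred_fake_A = D_B(fake_B), D_A(fake_A)
            loss_gan = (criterion_GAN(pred_fake_B, torch.ones_like(pred_fake_B)) +
                        criterion_GAN(pred_fake_A, torch.ones_like(pred_fake_A))) / 2
            loss_cycle = (criterion_cycle(G_BA(fake_B), real_A) +
                          criterion_cycle(G_AB(fake_A), real_B)) / 2
            loss_identity = (criterion_identity(G_BA(real_A), real_A) +
                             criterion_identity(G_AB(real_B), real_B)) / 2
            loss_G = loss_gan + 10.0 * loss_cycle + 5.0 * loss_identity
            loss_G.backward()
            optimizer_G.step()

            # discriminator update: real batches vs. replay-buffer fakes
            optimizer_D.zero_grad()
            buf_A = pool_A.push_and_pop(fake_A)
            buf_B = pool_B.push_and_pop(fake_B)
            pred_real_A, pred_buf_A = D_A(real_A), D_A(buf_A.detach())
            pred_real_B, pred_buf_B = D_B(real_B), D_B(buf_B.detach())
            loss_D = (criterion_GAN(pred_real_A, torch.ones_like(pred_real_A)) +
                      criterion_GAN(pred_buf_A, torch.zeros_like(pred_buf_A)) +
                      criterion_GAN(pred_real_B, torch.ones_like(pred_real_B)) +
                      criterion_GAN(pred_buf_B, torch.zeros_like(pred_buf_B))) / 4
            loss_D.backward()
            optimizer_D.step()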
Example #3
        _ = _display.start()

    if args.img_input:
        env.reset()
        env = PixelObservationWrapper(env)

    agent = Agent(env, args.alpha, args.beta, args.hidden_dims, args.tau,
                  args.batch_size, args.gamma, args.d, 0, args.max_size,
                  args.c * max_action, args.sigma * max_action,
                  args.one_device, args.log_dir, args.checkpoint_dir,
                  args.img_input, args.in_channels, args.order, args.depth,
                  args.multiplier, args.action_embed_dim, args.hidden_dim,
                  args.crop_dim, args.img_feature_dim)

    best_score = env.reward_range[0]
    load_weights(args.checkpoint_dir, [agent.actor], ['actor'])
    episodes = tqdm(range(args.n_episodes))

    for e in episodes:
        # resetting
        state = env.reset()
        if args.img_input:
            state_queue = deque(
                [preprocess_img(state['pixels'], args.crop_dim)
                 for _ in range(args.order)],
                maxlen=args.order)
            state = torch.cat(list(state_queue), 1).cpu().numpy()
        done, score = False, 0

        while not done:
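            # assumed continuation of the interaction loop; the Agent method
            # names (choose_action / remember / learn / save_checkpoints) are
            # guesses based on common TD3-style implementations
            action = agent.choose_action(state)
            next_state, reward, done, _ = env.step(action)
            if args.img_input:
                state_queue.append(
                    preprocess_img(next_state['pixels'], args.crop_dim))
                next_state = torch.cat(list(state_queue), 1).cpu().numpy()
            agent.remember(state, action, reward, next_state, done)
            agent.learn()
            score += reward
            state = next_state

        # track the best episode and checkpoint on improvement
        if score > best_score:
            best_score = score
            agent.save_checkpoints()
        episodes.set_postfix({'Score': score, 'Best Score': best_score})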
Example #4
def train():
    # creating dirs if needed
    Path(opt.checkpoint_dir).mkdir(parents=True, exist_ok=True)
    Path(opt.log_dir).mkdir(parents=True, exist_ok=True)
    if opt.save_local_samples:
        Path(opt.sample_dir).mkdir(parents=True, exist_ok=True)

    writer = SummaryWriter(opt.log_dir + f'/{int(datetime.now().timestamp()*1e6)}')
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    if opt.conditional:
        loader, num_classes = get_cifar_loader(opt.batch_size, opt.crop_size, opt.img_size)
    else:
        num_classes = 0
        loader = get_celeba_loaders(opt.data_path, opt.img_ext, opt.crop_size,
                                    opt.img_size, opt.batch_size, opt.download)

    G = torch.nn.DataParallel(Generator(opt.h_dim, opt.z_dim, opt.img_channels,
                                        opt.img_size, num_classes=num_classes),
                              device_ids=opt.devices).to(device)
    D = torch.nn.DataParallel(Discriminator(opt.img_channels, opt.h_dim,
                                            opt.img_size, num_classes=num_classes),
                              device_ids=opt.devices).to(device)

    if opt.criterion == 'wasserstein-gp':
        criterion = Wasserstein_GP_Loss(opt.lambda_gp)
    elif opt.criterion == 'hinge':
        criterion = Hinge_loss()
    else:
        raise NotImplementedError('Please choose criterion [hinge, wasserstein-gp]')
    optimizer_G = torch.optim.Adam(G.parameters(), lr=opt.lr_G, betas=opt.betas)
    optimizer_D = torch.optim.Adam(D.parameters(), lr=opt.lr_D, betas=opt.betas)

    # sample fixed z to see progress through training
    fixed_z = torch.randn(opt.sample_size, opt.z_dim).to(device)
    if opt.conditional:
        fixed_fake_labels = get_random_labels(num_classes, opt.sample_size, device)

    # if continue training, load weights, otherwise starting epoch=0
    if opt.continue_train:
        start_epoch = load_weights(state_dict_path=opt.checkpoint_dir,
                                   models=[D, G],
                                   model_names=['D', 'G'],
                                   optimizers=[optimizer_D, optimizer_G],
                                   optimizer_names=['optimizer_D', 'optimizer_G'],
                                   return_val='start_epoch')
    else:
        start_epoch = 0
    pbar = tqdm(range(start_epoch, opt.n_epochs))
    ckpt_iter = 0

    for epoch in pbar:
        d_losses, g_losses = [], []
        D.train()
        G.train()
        for batch_idx, data in enumerate(loader):
            # data prep
            if opt.conditional:
                reals, labels = [d.to(device) for d in data]
                fake_labels = get_random_labels(num_classes, reals.size(0), device)
            else:
                reals = data.to(device)
                fake_labels, labels = None, None
            z = torch.randn(reals.size(0), opt.z_dim).to(device)

            # forward generator
            optimizer_G.zero_grad()
            fakes = G(z, fake_labels)
            g_loss = criterion(fake_logits=D(fakes, fake_labels), mode='generator')

            # update gen
            g_loss.backward()
            optimizer_G.step()

            # forward discriminator
            optimizer_D.zero_grad()
            logits_fake = D(fakes.detach(), fake_labels)
            logits_real = D(reals, labels)

            # compute loss & update disc
            d_loss = criterion(fake_logits=logits_fake, real_logits=logits_real, mode='discriminator')

            # if wgangp, calculate gradient penalty and add to current d_loss
            if opt.criterion == 'wasserstein-gp':
                interpolates = criterion.get_interpolates(reals, fakes)
                interpolated_logits = D(interpolates, labels)
                grad_penalty = criterion.grad_penalty_loss(interpolates, interpolated_logits)
                d_loss = d_loss + grad_penalty

            d_loss.backward()
            optimizer_D.step()

            # logging
            d_losses.append(d_loss.item())
            g_losses.append(g_loss.item())
            pbar.set_postfix({
                'G Loss': g_loss.item(),
                'D Loss': d_loss.item(),
                'Batch ID': batch_idx})

            # tensorboard logging of samples every `cpt_interval` batches
            if batch_idx % opt.cpt_interval == 0:
                ckpt_iter += 1
                G.eval()
                # generate image from fixed noise vector
                with torch.no_grad():
                    if opt.conditional:
                        samples = G(fixed_z, fixed_fake_labels)
                    else:
                        samples = G(fixed_z)
                    samples = (samples + 1) / 2

                # save locally
                if opt.save_local_samples:
                    torchvision.utils.save_image(samples, f'{opt.sample_dir}/Interval_{ckpt_iter}.{opt.img_ext}')

                # save sample and loss to tensorboard
                writer.add_image('Generated Images', torchvision.utils.make_grid(samples), global_step=ckpt_iter)
                writer.add_scalars("Train Losses", {
                    "Discriminator Loss": sum(d_losses) / len(d_losses),
                    "Generator Loss": sum(g_losses) / len(g_losses)
                }, global_step=ckpt_iter)

                # resetting
                G.train()

        # printing loss and save weights
        tqdm.write(
            f'Epoch {epoch + 1}/{opt.n_epochs}, '
            f'Discriminator loss: {sum(d_losses) / len(d_losses):.3f}, '
            f'Generator Loss: {sum(g_losses) / len(g_losses):.3f}'
        )
        torch.save({
                'D': D.state_dict(),
                'G': G.state_dict(),
                'optimizer_D': optimizer_D.state_dict(),
                'optimizer_G': optimizer_G.state_dict(),
                'start_epoch': epoch + 1
            }, f"{opt.checkpoint_dir}/SA_GAN.pth")
Example #5
    criterion = MoCoLoss(args.temperature)
    optimizer = torch.optim.SGD(f_q.parameters(),
                                lr=args.lr,
                                momentum=args.momentum,
                                weight_decay=args.wd)
    scheduler = torch.optim.lr_scheduler.MultiplicativeLR(
        optimizer, lambda epoch: 0.1 if epoch in (120, 160) else 1)
    memo_bank = MemoryBank(f_k, device, momentum_loader, args.K)
    writer = SummaryWriter(args.logs_root +
                           f'/{int(datetime.now().timestamp()*1e6)}')

    start_epoch = 0
    if args.continue_train:
        start_epoch = load_weights(state_dict_path=args.check_point,
                                   models=[f_q, f_k],
                                   model_names=['f_q', 'f_k'],
                                   optimizers=[optimizer],
                                   optimizer_names=['optimizer'],
                                   return_val='start_epoch')

    pbar = tqdm(range(start_epoch, args.epochs))
    for epoch in pbar:
        train_losses = []
        f_q.train()
        f_k.train()
        for x1, x2 in train_loader:
            # move both augmented views to the training device
            x1, x2 = x1.to(device), x2.to(device)
            q1, q2 = f_q(x1), f_q(x2)
            with torch.no_grad():
                momentum_update(f_k, f_q, args.m)
                k1, k2 = f_k(x1), f_k(x2)
            loss = criterion(q1, k2, memo_bank) + criterion(q2, k1, memo_bank)
            optimizer.zero_grad()
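            # assumed continuation: the excerpt cuts off before the update,
            # so everything below is a sketch; the MemoryBank update method
            # name is a guess
            loss.backward()
            optimizer.step()
            memo_bank.update(k1, k2)  # enqueue fresh keys into the queue
            train_losses.append(loss.item())
        scheduler.step()
        writer.add_scalar('Train Loss',
                          sum(train_losses) / len(train_losses), epoch)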