def train_bottleneck_model(nb_epochs, batch_size):
    """Train bottleneck model"""
    # Load the training and validation bottleneck features
    train_data = load_np_array(cfg.bf_train_path)
    val_data = load_np_array(cfg.bf_val_path)

    # Build training and validation labels for the bottleneck features
    # (the features were generated class-by-class, in cfg.classes order)
    train_labels = []
    val_labels = []
    k = 0
    for class_name in cfg.classes:
        train_labels += [k] * len(
            os.listdir(osp.join(cfg.train_data_dir, class_name)))
        val_labels += [k] * len(
            os.listdir(osp.join(cfg.val_data_dir, class_name)))
        k += 1

    # Create custom model
    model = BottleneckModel.build(input_shape=train_data.shape[1:],
                                  nb_classes=cfg.nb_classes)

    # If multiclass, one-hot encode the labels (1-of-K binary format)
    if cfg.nb_classes != 2:
        train_labels = np_utils.to_categorical(train_labels, cfg.nb_classes)
        val_labels = np_utils.to_categorical(val_labels, cfg.nb_classes)

    # Compile model
    model.compile(loss='{}_crossentropy'.format(cfg.classmode),
                  optimizer=Adam(lr=5e-5),
                  metrics=['accuracy'])

    # Print model summary
    model.summary()

    # Save weights with best val loss
    model_checkpoint = ModelCheckpoint(cfg.model_weights_path,
                                       save_best_only=True,
                                       save_weights_only=True,
                                       monitor='val_loss')

    # Decay learning rate by half every 20 epochs
    decay = decay_lr(20, 0.5)

    # Start training
    history = model.fit(train_data,
                        train_labels,
                        nb_epoch=nb_epochs,
                        batch_size=batch_size,
                        validation_data=(val_data, val_labels),
                        callbacks=[model_checkpoint, decay])

    # Load best weights to get val data predictions
    model.load_weights(cfg.model_weights_path)

    # Get val data predictions
    val_pred_proba = model.predict(val_data)

    return model, history, val_pred_proba
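The decay_lr(20, 0.5) callback above comes from a project helper that is not shown here. A minimal sketch of such a helper, assuming it wraps Keras' LearningRateScheduler and that the base rate matches the Adam(lr=5e-5) used above (the project's actual implementation may differ):

from keras.callbacks import LearningRateScheduler


def decay_lr(decay_epochs, decay_factor, base_lr=5e-5):
    """Hypothetical sketch: scale base_lr by decay_factor once every
    decay_epochs epochs (here: halve the learning rate every 20 epochs)."""
    def schedule(epoch, lr=None):  # newer Keras versions also pass the current lr
        return base_lr * (decay_factor ** (epoch // decay_epochs))
    return LearningRateScheduler(schedule)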
def run(net,
        loader,
        optimizer,
        scheduler,
        tracker,
        train=False,
        has_answers=True,
        prefix='',
        epoch=0):
    """ Run an epoch over the given loader """
    assert not (train and not has_answers)
    if train:
        net.train()
        tracker_class, tracker_params = tracker.MovingMeanMonitor, {
            'momentum': 0.99
        }
    else:
        net.eval()
        tracker_class, tracker_params = tracker.MeanMonitor, {}
        answ = []
        idxs = []
        accs = []

    # set learning rate decay policy
    if (epoch < len(config.gradual_warmup_steps)
            and config.schedule_method == 'warm_up'):
        utils.set_lr(optimizer, config.gradual_warmup_steps[epoch])
        utils.print_lr(optimizer, prefix, epoch)
    elif (epoch in config.lr_decay_epochs and train
            and config.schedule_method == 'warm_up'):
        utils.decay_lr(optimizer, config.lr_decay_rate)
        utils.print_lr(optimizer, prefix, epoch)
    else:
        utils.print_lr(optimizer, prefix, epoch)

    loader = tqdm(loader, desc='{} E{:03d}'.format(prefix, epoch), ncols=0)
    loss_tracker = tracker.track('{}_loss'.format(prefix),
                                 tracker_class(**tracker_params))
    acc_tracker = tracker.track('{}_acc'.format(prefix),
                                tracker_class(**tracker_params))

    for v, q, a, b, idx, v_mask, q_mask, q_len in loader:
        var_params = {
            'requires_grad': False,
        }
        v = Variable(v.cuda(), **var_params)
        q = Variable(q.cuda(), **var_params)
        a = Variable(a.cuda(), **var_params)
        b = Variable(b.cuda(), **var_params)
        q_len = Variable(q_len.cuda(), **var_params)
        v_mask = Variable(v_mask.cuda(), **var_params)
        q_mask = Variable(q_mask.cuda(), **var_params)

        out = net(v, b, q, v_mask, q_mask, q_len)
        if has_answers:
            answer = utils.process_answer(a)
            loss = utils.calculate_loss(answer, out, method=config.loss_method)
            acc = utils.batch_accuracy(out, answer).data.cpu()

        if train:
            optimizer.zero_grad()
            loss.backward()
            # print gradient
            if config.print_gradient:
                utils.print_grad([(n, p) for n, p in net.named_parameters()
                                  if p.grad is not None])
            # clip gradient
            clip_grad_norm_(net.parameters(), config.clip_value)
            optimizer.step()
            if (config.schedule_method == 'batch_decay'):
                scheduler.step()
        else:
            # store information about evaluation of this minibatch
            _, answer = out.data.cpu().max(dim=1)
            answ.append(answer.view(-1))
            if has_answers:
                accs.append(acc.view(-1))
            idxs.append(idx.view(-1).clone())

        if has_answers:
            loss_tracker.append(loss.item())
            acc_tracker.append(acc.mean())
            fmt = '{:.4f}'.format
            loader.set_postfix(loss=fmt(loss_tracker.mean.value),
                               acc=fmt(acc_tracker.mean.value))

    if not train:
        answ = list(torch.cat(answ, dim=0))
        if has_answers:
            accs = list(torch.cat(accs, dim=0))
        else:
            accs = []
        idxs = list(torch.cat(idxs, dim=0))
        #print('{} E{:03d}:'.format(prefix, epoch), ' Total num: ', len(accs))
        #print('{} E{:03d}:'.format(prefix, epoch), ' Average Score: ', float(sum(accs) / len(accs)))
        return answ, accs, idxs
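run() delegates the warm-up schedule to a project utils module (utils.set_lr, utils.decay_lr, utils.print_lr). A minimal sketch of what these helpers are assumed to do on a PyTorch optimizer; the real module may differ:

def set_lr(optimizer, value):
    # overwrite the learning rate of every parameter group
    for group in optimizer.param_groups:
        group['lr'] = value


def decay_lr(optimizer, decay_rate):
    # multiply the learning rate of every parameter group by decay_rate
    for group in optimizer.param_groups:
        group['lr'] *= decay_rate


def print_lr(optimizer, prefix, epoch):
    # log the current learning rate(s) for this epoch
    lrs = [group['lr'] for group in optimizer.param_groups]
    print('{} E{:03d}: lr = {}'.format(prefix, epoch, lrs))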
def train():
    parser = argparse.ArgumentParser()
    parser.add_argument("-conf", type=str)
    parser.add_argument("--debug", action="store_true")
    args = parser.parse_args()

    debug = args.debug
    config = configparser.ConfigParser()
    config.read(args.conf)

    log_path = config["log"]["log_path"]
    log_step = int(config["log"]["log_step"])
    log_dir = os.path.dirname(log_path)
    os.makedirs(log_dir, exist_ok=True)

    save_prefix = config["save"]["save_prefix"]
    save_format = save_prefix + ".network.epoch{}"
    optimizer_save_format = save_prefix + ".optimizer.epoch{}"
    save_step = int(config["save"]["save_step"])
    save_dir = os.path.dirname(save_prefix)
    os.makedirs(save_dir, exist_ok=True)

    num_epochs = int(config["train"]["num_epochs"])
    batch_size = int(config["train"]["batch_size"])
    decay_start_epoch = int(config["train"]["decay_start_epoch"])
    decay_rate = float(config["train"]["decay_rate"])
    vocab_size = int(config["vocab"]["vocab_size"])
    ls_prob = float(config["train"]["ls_prob"])
    distill_weight = float(config["distill"]["distill_weight"])

    if debug:
        logging.basicConfig(format="%(asctime)s %(message)s",
                            level=logging.INFO)  # log to the console
    else:
        logging.basicConfig(filename=log_path,
                            format="%(asctime)s %(message)s",
                            level=logging.DEBUG)

    model = AttnModel(args.conf)
    model.apply(init_weight)
    model.to(device)
    optimizer = optim.Adam(model.parameters(), weight_decay=1e-5)

    dataset = SpeechDataset(args.conf)
    dataloader = DataLoader(dataset=dataset,
                            batch_size=batch_size,
                            shuffle=False,
                            collate_fn=collate_fn,
                            num_workers=2,
                            pin_memory=True)
    num_steps = len(dataloader)

    for epoch in range(num_epochs):
        loss_sum = 0

        for step, data in enumerate(dataloader):
            loss_step = train_step(model, optimizer, data, vocab_size, ls_prob,
                                   distill_weight)
            loss_sum += loss_step

            if (step + 1) % log_step == 0:
                logging.info(
                    "epoch = {:>2} step = {:>6} / {:>6} loss = {:.3f}".format(
                        epoch + 1, step + 1, num_steps, loss_sum / log_step))
                loss_sum = 0

        if epoch == 0 or (epoch + 1) % save_step == 0:
            save_path = save_format.format(epoch + 1)
            torch.save(model.state_dict(), save_path)
            optimizer_save_path = optimizer_save_format.format(epoch + 1)
            torch.save(optimizer.state_dict(), optimizer_save_path)
            logging.info("model saved to: {}".format(save_path))
            logging.info("optimizer saved to: {}".format(optimizer_save_path))
        update_epoch(model, epoch + 1)
        decay_lr(optimizer, epoch + 1, decay_start_epoch, decay_rate)
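The decay_lr call at the end of each epoch is assumed to apply a simple exponential decay once decay_start_epoch has been reached; a sketch under that assumption (the project's actual helper may differ):

def decay_lr(optimizer, epoch, decay_start_epoch, decay_rate):
    # hypothetical sketch: after decay_start_epoch, multiply the learning
    # rate of every parameter group by decay_rate once per epoch
    if epoch >= decay_start_epoch:
        for group in optimizer.param_groups:
            group['lr'] *= decay_rate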
Example 4
def run(net,
        loader,
        optimizer,
        scheduler,
        tracker,
        train=False,
        prefix='',
        epoch=0):
    """ Run an epoch over the given loader """
    if train:
        net.train()
        # tracker_class, tracker_params = tracker.MovingMeanMonitor, {'momentum': 0.99}
    else:
        net.eval()

    tracker_class, tracker_params = tracker.MeanMonitor, {}

    # set learning rate decay policy
    if (epoch < len(config.gradual_warmup_steps)
            and config.schedule_method == 'warm_up'):
        utils.set_lr(optimizer, config.gradual_warmup_steps[epoch])

    elif (epoch in config.lr_decay_epochs and train
            and config.schedule_method == 'warm_up'):
        utils.decay_lr(optimizer, config.lr_decay_rate)

    utils.print_lr(optimizer, prefix, epoch)

    loader = tqdm(loader, desc='{} E{:03d}'.format(prefix, epoch), ncols=0)
    loss_tracker = tracker.track('{}_loss'.format(prefix),
                                 tracker_class(**tracker_params))
    acc_tracker = tracker.track('{}_acc'.format(prefix),
                                tracker_class(**tracker_params))

    for v, q, a, b, idx, v_mask, q_mask, q_len in loader:
        var_params = {
            'requires_grad': False,
        }
        v = Variable(v.cuda(), **var_params)
        q = Variable(q.cuda(), **var_params)
        a = Variable(a.cuda(), **var_params)
        b = Variable(b.cuda(), **var_params)
        q_len = Variable(q_len.cuda(), **var_params)
        v_mask = Variable(v_mask.cuda(), **var_params)
        q_mask = Variable(q_mask.cuda(), **var_params)

        out = net(v, b, q, v_mask, q_mask, q_len)

        answer = utils.process_answer(a)
        loss = utils.calculate_loss(answer, out, method=config.loss_method)
        acc = utils.batch_accuracy(out, answer).data.cpu()

        if train:
            optimizer.zero_grad()
            loss.backward()
            # clip gradient
            clip_grad_norm_(net.parameters(), config.clip_value)
            optimizer.step()
            if config.schedule_method == 'batch_decay':
                scheduler.step()

        loss_tracker.append(loss.item())
        acc_tracker.append(acc.mean())
        fmt = '{:.4f}'.format
        loader.set_postfix(loss=fmt(loss_tracker.mean.value),
                           acc=fmt(acc_tracker.mean.value))

    return acc_tracker.mean.value, loss_tracker.mean.value
Example 5
    def train(self):
        loss = {}
        nrow = min(int(np.sqrt(self.batch_size)), 8)
        n_samples = nrow * nrow
        iter_per_epoch = len(self.train_loader.dataset) // self.batch_size
        max_iteration = self.num_epoch * iter_per_epoch
        lambda_l1 = 0.2
        print('Start training...')
        for epoch in tqdm(range(self.resume_epoch, self.num_epoch)):
            for i, (x_real, noise,
                    label) in enumerate(tqdm(self.train_loader)):

                # lr decay
                if epoch * iter_per_epoch + i >= self.lr_decay_start:
                    utils.decay_lr(self.g_optimizer, max_iteration,
                                   self.lr_decay_start, self.g_lr)
                    utils.decay_lr(self.d_optimizer, max_iteration,
                                   self.lr_decay_start, self.d_lr)
                    if i % 1000 == 0:
                        print('d_lr / g_lr is updated to {:.8f} / {:.8f} !'.
                              format(self.d_optimizer.param_groups[0]['lr'],
                                     self.g_optimizer.param_groups[0]['lr']))

                x_real = x_real.to(self.device)
                noise = noise.to(self.device)
                label = label.to(self.device)
                # ============================================================ #
                #                  1. Train the discriminator                  #
                # ============================================================ #
                for param in self.D.parameters():
                    param.requires_grad = True

                dis_real, real_list = self.D(x_real, label)
                real_list = [h.detach() for h in real_list]

                x_fake = self.G(noise, label).detach()
                dis_fake, _ = self.D(x_fake, label)

                d_loss_real, d_loss_fake = self.dis_hinge(dis_real, dis_fake)

                # sample
                try:
                    x_real2, label2 = next(real_iter)
                except (StopIteration, NameError):  # real_iter exhausted or not yet created
                    real_iter = iter(self.real_loader)
                    x_real2, label2 = next(real_iter)
                x_real2 = x_real2.to(self.device)
                label2 = label2.to(self.device)

                noise2 = torch.FloatTensor(utils.truncated_normal(self.batch_size*self.z_dim)) \
                      .view(self.batch_size, self.z_dim).to(self.device)
                #				 noise2 = torch.randn(self.batch_size, self.z_dim).to(self.device)
                dis_real2, _ = self.D(x_real2, label2)
                x_fake2 = self.G(noise2, label2).detach()
                dis_fake2, _ = self.D(x_fake2, label2)
                d_loss_real2, d_loss_fake2 = self.dis_hinge(
                    dis_real2, dis_fake2)

                # Backward and optimize.
                d_loss = d_loss_real + d_loss_fake + 0.2 * (d_loss_real2 +
                                                            d_loss_fake2)

                self.d_optimizer.zero_grad()
                d_loss.backward()
                self.d_optimizer.step()

                # Logging.
                loss['D/loss_real'] = d_loss_real.item()
                loss['D/loss_fake'] = d_loss_fake.item()
                loss['D/loss_real2'] = d_loss_real2.item()
                loss['D/loss_fake2'] = d_loss_fake2.item()

                # ============================================================ #
                #                    2. Train the generator                    #
                # ============================================================ #

                x_fake = self.G(noise, label)

                for param in self.D.parameters():
                    param.requires_grad = False

                dis_fake, fake_list = self.D(x_fake, label)

                g_loss_feat = self.KDLoss(real_list, fake_list)
                g_loss_pix = F.l1_loss(x_fake, x_real)
                g_loss = g_loss_feat + lambda_l1 * g_loss_pix
                loss['G/loss_ft'] = g_loss_feat.item()
                loss['G/loss_l1'] = g_loss_pix.item()

                if (i + 1) % self.n_critic == 0:
                    dis_fake, _ = self.D(x_fake, label)
                    g_loss_fake = self.gen_hinge(dis_fake)

                    g_loss += self.lambda_gan * g_loss_fake

                    # sample
                    noise2 = torch.FloatTensor(utils.truncated_normal(self.batch_size*self.z_dim)) \
                         .view(self.batch_size, self.z_dim).to(self.device)
                    #					 noise2 = torch.randn(self.batch_size, self.z_dim).to(self.device)
                    x_fake2 = self.G(noise2, label2)
                    dis_fake2, _ = self.D(x_fake2, label2)
                    g_loss_fake2 = self.gen_hinge(dis_fake2)
                    g_loss += 0.2 * self.lambda_gan * g_loss_fake2

                    loss['G/loss_fake'] = g_loss_fake.item()
                    loss['G/loss_fake2'] = g_loss_fake2.item()

                self.g_optimizer.zero_grad()
                g_loss.backward()
                self.g_optimizer.step()

                # ============================================================ #
                #                       3. Miscellaneous                       #
                # ============================================================ #

                # Print out training information.
                if (i + 1) % self.log_step == 0:
                    log = "[{}/{}]".format(epoch, i)
                    for tag, value in loss.items():
                        log += ", {}: {:.4f}".format(tag, value)
                    print(log)

                    if self.use_tensorboard:
                        for tag, value in loss.items():
                            self.logger.scalar_summary(tag, value, i + 1)

            if epoch == 0 or (epoch + 1) % self.sample_step == 0:
                with torch.no_grad():
                    """
					# randomly sampled noise
					noise = torch.FloatTensor(utils.truncated_normal(n_samples*self.z_dim)) \
										.view(n_samples, self.z_dim).to(self.device)
					label = label[:nrow].repeat(nrow)

					#label = np.random.choice(1000, nrow, replace=False)
					#label = torch.tensor(label).repeat(10).to(self.device)
					x_sample = self.G(noise, label)
					sample_path = os.path.join(self.sample_dir, '{}-sample.png'.format(epoch+1))
					save_image(utils.denorm(x_sample.cpu()), sample_path, nrow=nrow, padding=0)
					"""
                    # recons
                    n = min(x_real.size(0), 8)
                    comparison = torch.cat([x_real[:n], x_fake[:n]])
                    sample_path = os.path.join(
                        self.sample_dir, '{}-train.png'.format(epoch + 1))
                    save_image(utils.denorm(comparison.cpu()), sample_path)
                    print('Save fake images into {}...'.format(sample_path))

                    # noise2
                    comparison = torch.cat([x_real2[:n], x_fake2[:n]])
                    sample_path = os.path.join(
                        self.sample_dir, '{}-random.png'.format(epoch + 1))
                    save_image(utils.denorm(comparison.cpu()), sample_path)
                    print('Save fake images into {}...'.format(sample_path))

                    # noise sampled from BigGAN's test set
                    try:
                        x_real, noise, label = next(test_iter)
                    except (StopIteration, NameError):  # test_iter exhausted or not yet created
                        test_iter = iter(self.test_loader)
                        x_real, noise, label = next(test_iter)
                    noise = noise.to(self.device)
                    label = label.to(self.device)

                    x_fake = self.G(noise, label).detach().cpu()
                    n = min(x_real.size(0), 8)
                    comparison = torch.cat([x_real[:n], x_fake[:n]])
                    sample_path = os.path.join(self.sample_dir,
                                               '{}-test.png'.format(epoch + 1))
                    save_image(utils.denorm(comparison.cpu()), sample_path)
                    print('Save fake images into {}...'.format(sample_path))

            lambda_l1 = max(0.00, lambda_l1 - 0.01)
            # Save model checkpoints.
            if (epoch + 1) % self.model_save_step == 0:
                utils.save_model(self.model_save_dir, epoch + 1, self.G,
                                 self.D, self.g_optimizer, self.d_optimizer)
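The GAN training loops in this file call utils.decay_lr(optimizer, max_iter, decay_start, base_lr) once per iteration after decay_start. A minimal sketch, assuming the common SNGAN/BigGAN-style linear decay of the learning rate to zero over the remaining iterations (the real utils module may differ):

def decay_lr(optimizer, max_iter, decay_start, base_lr):
    # hypothetical sketch: subtract a fixed fraction of base_lr per call so
    # that, called once per iteration, the rate reaches zero at max_iter
    step = base_lr / float(max_iter - decay_start)
    for group in optimizer.param_groups:
        group['lr'] = max(group['lr'] - step, 0.0)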
Example 6
def train():
    dir_name, save_path = getSavePath()
    if args.norm == 'sn':
        netG, netD = networks.getGD_SN(args.structure, args.dataset,
                                       args.Gnum_features, args.Dnum_features)
    elif args.norm == 'bn':
        netG, netD = networks.getGD_batchnorm(args.structure,
                                              args.dataset,
                                              args.Gnum_features,
                                              args.Dnum_features,
                                              dim_z=args.input_dim)

    if args.ema_trick:
        ema_netG_9999 = copy.deepcopy(netG)

    netG.cuda()
    netD.cuda()
    g_optimizer = torch.optim.Adam(netG.parameters(),
                                   lr=args.g_lr,
                                   betas=(args.beta1, args.beta2))
    d_optimizer = torch.optim.Adam(netD.parameters(),
                                   lr=args.d_lr,
                                   betas=(args.beta1, args.beta2))

    g_losses, d_losses = [], []
    grad_normD, grad_normG = [], []

    loader = datasets.getDataLoader(args.dataset,
                                    args.image_size,
                                    batch_size=args.batch_size,
                                    shuffle=not args.fixz)
    data_iter = iter(loader)

    data_num = len(loader.dataset)
    zs = TensorDataset(torch.randn(data_num, args.input_dim))
    zloader = DataLoader(zs,
                         batch_size=args.batch_size,
                         shuffle=not args.fixz,
                         num_workers=0,
                         drop_last=True)
    z_iter = iter(zloader)

    for i in range(1, args.num_iters + 1):
        if i >= args.lr_decay_start:
            utils.decay_lr(g_optimizer, args.num_iters, args.lr_decay_start,
                           args.g_lr)
            utils.decay_lr(d_optimizer, args.num_iters, args.lr_decay_start,
                           args.d_lr)

        if i == 1:
            torch.save(netG.state_dict(), save_path + 'G_epoch0.pth')
            torch.save(netD.state_dict(), save_path + 'D_epoch0.pth')

        # G-step
        for _ in range(args.g_freq):
            if args.fixz:
                try:
                    z = next(z_iter)[0].cuda()
                except StopIteration:
                    z_iter = iter(zloader)
                    z = next(z_iter)[0].cuda()
            else:
                z = torch.randn(args.batch_size, args.input_dim, device=device)
            g_optimizer.zero_grad()
            x_hat = netG(z)
            y_hat = netD(x_hat)
            g_loss = get_gloss(args.losstype, y_hat)
            g_losses.append(g_loss.item())
            g_loss.backward()
            g_optimizer.step()
            grad_normG.append(utils.getGradNorm(netG))

            if args.ema_trick:
                moving_average.soft_copy_param(ema_netG_9999, netG, 0.9999)

        for _ in range(args.d_freq):
            try:
                x = next(data_iter)[0].cuda().float()
            except StopIteration:
                data_iter = iter(loader)
                x = next(data_iter)[0].cuda().float()
            if args.fixz:
                try:
                    z = next(z_iter)[0].cuda()
                except StopIteration:
                    z_iter = iter(zloader)
                    z = next(z_iter)[0].cuda()
            else:
                z = torch.randn(args.batch_size, args.input_dim, device=device)

            d_optimizer.zero_grad()
            x_hat = netG(z).detach()
            y_hat = netD(x_hat)
            y = netD(x)
            d_loss = get_dloss(args.losstype, y_hat, y)
            d_losses.append(d_loss.item())

            d_loss.backward()
            d_optimizer.step()
            grad_normD.append(utils.getGradNorm(netD))

        if i % args.print_freq == 0:
            print('Iteration: {}; G-Loss: {}; D-Loss: {};'.format(
                i, g_loss, d_loss))

        if i == 1:
            save_image((x / 2. + 0.5)[:36], os.path.join(dir_name, 'real.png'))

        if i == 1 or i % args.plot_freq == 0:
            plot_x = netG(torch.randn(36, args.input_dim, device=device)).data
            plot_x = plot_x / 2. + 0.5
            save_image(
                plot_x,
                os.path.join(dir_name, 'fake_images-{}.png'.format(i + 1)))
            utils.plot_losses(g_losses, d_losses, grad_normG, grad_normD,
                              dir_name)
            utils.saveproj(y.cpu(), y_hat.cpu(), i, save_path)

        if i % args.save_freq == 0:
            torch.save(netG.state_dict(),
                       save_path + 'G_epoch{}.pth'.format(i))
            torch.save(netD.state_dict(),
                       save_path + 'D_epoch{}.pth'.format(i))
            if args.ema_trick:
                torch.save(ema_netG_9999.state_dict(),
                           save_path + 'emaG0.9999_epoch{}.pth'.format(i))
Example 7
def SWAG(model,
         dataloader,
         optimizer,
         criterion,
         epochs=3000,
         print_freq=1000,
         swag_start=2000,
         M=1e5,
         lr_ratio=1,
         verbose=False):
    '''Implementation of SWAG (Stochastic Weight Averaging-Gaussian)'''
    model.train()  # prep model layers for training

    # initialize first moment as vectors w/ length = num model parameters
    num_params = sum(param.numel() for param in model.parameters())
    first_moment = torch.zeros(num_params)

    # initialize deviation matrix 'A'
    A = torch.empty(0, num_params, dtype=torch.float32)
    lr_init = optimizer.defaults['lr']

    n_iterates = 0
    for epoch in tqdm(range(epochs)):
        epoch_loss = 0

        # Implementation of learning rate decay from paper
        epoch_ratio = (epoch + 1) / swag_start
        lr = decay_lr(optimizer,
                      epoch_ratio,
                      lr_init=lr_init,
                      lr_ratio=lr_ratio)

        for inputs, labels in dataloader:
            optimizer.zero_grad()  # clear gradients

            preds = model(inputs)  # perform a forward pass
            loss = criterion(preds, labels)  # compute the loss

            loss.backward()  # backpropagate
            optimizer.step()  # update the weights

            epoch_loss += loss.data.item() * inputs.shape[0]

        # Print output
        if (epoch % print_freq == 0 or epoch == epochs - 1) and verbose:
            print('Epoch %d | LR: %g | Loss: %.4f' % (epoch, lr, epoch_loss))

        # Collect weight iterates for the SWAG first moment and deviation matrix
        if epoch > swag_start:
            # obtain a flattened vector of weights
            weights_list = [param.detach() for param in model.parameters()]
            w = torch.cat([w.contiguous().view(-1, 1)
                           for w in weights_list]).view(-1)

            # update the first moment
            first_moment = (n_iterates * first_moment + w) / (n_iterates + 1)

            # update 'a' matrix  (following their code implementation)
            a = w - first_moment
            A = torch.cat((A, a.view(1, -1)), dim=0)

            # keep only the most recent 'M' deviation vectors (rows of A)
            if A.shape[0] > M:
                A = A[1:, :]

            n_iterates += 1

    return first_moment.double(), A.numpy()
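SWAG() above leaves the schedule to decay_lr(optimizer, epoch_ratio, lr_init, lr_ratio). A sketch assuming the schedule used in the SWA/SWAG papers: hold lr_init for the first half of the run, decay linearly to lr_init * lr_ratio by 90%, then hold that low rate (the actual helper may differ):

def decay_lr(optimizer, epoch_ratio, lr_init, lr_ratio):
    # hypothetical sketch of an SWA-style schedule driven by epoch_ratio
    if epoch_ratio <= 0.5:
        factor = 1.0
    elif epoch_ratio <= 0.9:
        factor = 1.0 - (1.0 - lr_ratio) * (epoch_ratio - 0.5) / 0.4
    else:
        factor = lr_ratio
    lr = lr_init * factor
    for group in optimizer.param_groups:
        group['lr'] = lr
    return lr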
Example 8
def train():
    dir_name, save_path = getSavePath()
    netG, netD = networks.getGD_SN(args.structure, args.dataset, args.image_size, args.num_features, 
                                    dim_z=args.input_dim, bottleneck=args.bottleneck)

    if args.ema_trick:
        ema_netG_9999 = copy.deepcopy(netG)

    if args.reload > 0:
        netG.load_state_dict(torch.load(save_path + 'G_epoch{}.pth'.format(args.reload)))
        netD.load_state_dict(torch.load(save_path + 'D_epoch{}.pth'.format(args.reload)))
        if args.ema_trick:
            ema_netG_9999.load_state_dict(
                torch.load(save_path + 'emaG0.9999_epoch{}.pth'.format(args.reload), map_location=torch.device('cpu')))

    netG.cuda()
    netD.cuda()
    
    g_optimizer = torch.optim.Adam(netG.parameters(), lr=args.g_lr, betas=(args.beta1, args.beta2))
    d_optimizer = torch.optim.Adam(netD.parameters(), lr=args.d_lr, betas=(args.beta1, args.beta2))

    g_losses, d_losses = [], []
    grad_normD, grad_normG = [], []

    loader = datasets.getDataLoader(args.dataset, args.image_size, batch_size=args.batch_size)
    data_iter = iter(loader)

    for i in range(1, args.num_iters+1):
        if i >= args.lr_decay_start:
            utils.decay_lr(g_optimizer, args.num_iters, args.lr_decay_start, args.g_lr)
            utils.decay_lr(d_optimizer, args.num_iters, args.lr_decay_start, args.d_lr)
        if i <= args.reload:
            continue
        if i == 1:
            torch.save(netG.state_dict(), save_path + 'G_epoch0.pth')
            torch.save(netD.state_dict(), save_path + 'D_epoch0.pth')
        # G-step
        for _ in range(args.g_freq):
            try:
                x = next(data_iter)[0].cuda().float()
            except StopIteration:
                data_iter = iter(loader)
                x = next(data_iter)[0].cuda().float()
            z = torch.randn(args.batch_size, args.input_dim, device=device)
            g_optimizer.zero_grad()
            x_hat = netG(z)
            y_hat = netD(x_hat)
            y = netD(x)
            g_loss = get_gloss(args.losstype, y_hat, y)
            g_losses.append(g_loss.item())
            g_loss.backward()
            g_optimizer.step()
            grad_normG.append(utils.getGradNorm(netG))

            if args.ema_trick:
                utils.soft_copy_param(ema_netG_9999, netG, 0.9999)

        for _ in range(args.d_freq):
            try:
                x = next(data_iter)[0].cuda().float()
            except StopIteration:
                data_iter = iter(loader)
                x = next(data_iter)[0].cuda().float()
            z = torch.randn(args.batch_size, args.input_dim, device=device)

            d_optimizer.zero_grad()
            x_hat = netG(z).detach()
            y_hat = netD(x_hat)
            y = netD(x)
            d_loss = get_dloss(args.losstype, y_hat, y)
            d_losses.append(d_loss.item())
            d_loss.backward()
            d_optimizer.step()
            grad_normD.append(utils.getGradNorm(netD))
            netD.proj.weight.data = F.normalize(netD.proj.weight.data, dim=1)

        if i % args.print_freq == 0:
            print('Iteration: {}; G-Loss: {}; D-Loss: {};'.format(i, g_loss, d_loss))

        if i == 1:
            save_image((x / 2. + 0.5), os.path.join(dir_name, 'real.pdf'))

        if i == 1 or i % args.plot_freq == 0:
            plot_x = netG(torch.randn(args.batch_size, args.input_dim, device=device)).data
            plot_x = plot_x / 2. + 0.5
            save_image(plot_x, os.path.join(dir_name, 'fake_images-{}.pdf'.format(i + 1)))
            utils.plot_losses(g_losses, d_losses, grad_normG, grad_normD, dir_name)

        if i % args.save_freq == 0:
            torch.save(netG.state_dict(), save_path + 'G_epoch{}.pth'.format(i))
            torch.save(netD.state_dict(), save_path + 'D_epoch{}.pth'.format(i))
            if args.ema_trick:
                torch.save(ema_netG_9999.state_dict(), save_path + 'emaG0.9999_epoch{}.pth'.format(i))
Example 9
def train():
    with tf.Graph().as_default(), tf.device('/cpu:0'):
        assert FLAGS.batch_size % FLAGS.num_gpus == 0, (
            'Batch size must be divisible by number of GPUs')

        bs_l = FLAGS.batch_size
        num_iter_per_epoch = int(FLAGS.num_train_l / bs_l)
        max_steps = int(FLAGS.num_epochs * num_iter_per_epoch)
        num_classes = FLAGS.num_classes

        global_step = slim.create_global_step()
        lr = tf.placeholder(tf.float32, shape=[], name="learning_rate")
        opt = tf.train.MomentumOptimizer(learning_rate=lr,
                                         momentum=0.9,
                                         use_nesterov=True)

        images, labels = utils.prepare_traindata(FLAGS.dataset_dir_l,
                                                 int(bs_l))
        images_splits = tf.split(images, FLAGS.num_gpus, 0)
        labels_splits = tf.split(labels, FLAGS.num_gpus, 0)

        tower_grads = []
        top_1_op = []
        reuse_variables = None
        for i in range(FLAGS.num_gpus):
            with tf.device('/gpu:%d' % i):
                with tf.name_scope('%s_%d' % (network.TOWER_NAME, i)) as scope:
                    with slim.arg_scope([slim.model_variable, slim.variable],
                                        device='/cpu:0'):
                        loss, logits = \
                            _build_training_graph(images_splits[i], labels_splits[i], num_classes, reuse_variables)
                        top_1_op.append(
                            tf.nn.in_top_k(logits, labels_splits[i], 1))

                    reuse_variables = True
                    summaries = tf.get_collection(tf.GraphKeys.SUMMARIES,
                                                  scope)
                    batchnorm_updates = tf.get_collection(
                        tf.GraphKeys.UPDATE_OPS, scope)
                    grads = opt.compute_gradients(loss)
                    tower_grads.append(grads)

        grads = network.average_gradients(tower_grads)
        gradient_op = opt.apply_gradients(grads, global_step=global_step)

        var_averages = tf.train.ExponentialMovingAverage(
            FLAGS.ema_decay, global_step)
        var_op = var_averages.apply(tf.trainable_variables() +
                                    tf.moving_average_variables())

        batchnorm_op = tf.group(*batchnorm_updates)
        train_op = tf.group(gradient_op, var_op, batchnorm_op)

        saver = tf.train.Saver(tf.global_variables(), max_to_keep=None)
        summary_op = tf.summary.merge(summaries)
        init_op = tf.global_variables_initializer()

        config = tf.ConfigProto(allow_soft_placement=True,
                                log_device_placement=False)
        if FLAGS.gpu_memory:
            config.gpu_options.per_process_gpu_memory_fraction = FLAGS.gpu_memory
        sess = tf.Session(config=config)

        boundaries, values = utils.config_lr(max_steps)
        sess.run([init_op], feed_dict={lr: values[0]})

        tf.train.start_queue_runners(sess=sess)
        summary_writer = tf.summary.FileWriter(FLAGS.train_dir,
                                               graph=sess.graph)

        iter_count = epoch = sum_loss = sum_top_1 = 0
        start = time.time()

        for step in range(max_steps):

            decayed_lr = utils.decay_lr(step, boundaries, values, max_steps)
            _, loss_value, top_1_value = \
                sess.run([train_op, loss, top_1_op], feed_dict={lr: decayed_lr})

            sum_loss += loss_value
            top_1_value = np.sum(top_1_value) / bs_l
            sum_top_1 += top_1_value
            iter_count += 1

            assert not np.isnan(loss_value), 'Model diverged with loss = NaN'

            if step % num_iter_per_epoch == 0:
                end = time.time()
                sum_loss = sum_loss / num_iter_per_epoch
                sum_top_1 = min(sum_top_1 / num_iter_per_epoch, 1.0)
                time_per_iter = float(end - start) / iter_count
                format_str = (
                    'epoch %d, L = %.2f, top_1 = %.2f, lr = %.4f (time_per_iter: %.4f s)'
                )
                print(format_str % (epoch, sum_loss, sum_top_1 * 100,
                                    decayed_lr, time_per_iter))
                epoch += 1
                sum_loss = sum_top_1 = 0

            if step % 100 == 0:
                summary_str = sess.run(summary_op, feed_dict={lr: decayed_lr})
                summary_writer.add_summary(summary_str, step)

            if (step + 1) == max_steps:
                checkpoint_path = os.path.join(FLAGS.train_dir, 'model.ckpt')
                saver.save(sess, checkpoint_path, global_step=epoch)
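Here utils.decay_lr(step, boundaries, values, max_steps) runs on the host and is fed through the lr placeholder. A minimal sketch, assuming a piecewise-constant schedule where boundaries are global-step thresholds returned by utils.config_lr and values holds one more entry than boundaries (the real helper may differ, e.g. by deriving the thresholds from max_steps):

def decay_lr(step, boundaries, values, max_steps):
    # hypothetical sketch: return the learning rate for the interval that
    # step falls into; max_steps is kept only to match the call signature
    for boundary, value in zip(boundaries, values):
        if step < boundary:
            return value
    return values[-1]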