Example #1
def train(epoch, net, trainloader, device, optimizer, scheduler, loss_fn, max_grad_norm):
    global global_step
    print('\nEpoch: %d' % epoch)
    net.train()
    loss_meter = util.AverageMeter()
    with tqdm(total=len(trainloader.dataset)) as progress_bar:
        for x, _ in trainloader:
            x = x.to(device)
            optimizer.zero_grad()
            z, sldj = net(x, reverse=False)
            loss = loss_fn(z, sldj)
            loss_meter.update(loss.item(), x.size(0))
            loss.backward()
            if max_grad_norm > 0:
                util.clip_grad_norm(optimizer, max_grad_norm)
                
            optimizer.step()

            progress_bar.set_postfix(nll=loss_meter.avg,
                                     bpd=util.bits_per_dim(x, loss_meter.avg),
                                     lr=optimizer.param_groups[0]['lr'])
            progress_bar.update(x.size(0))

            scheduler.step(global_step)
            global_step += x.size(0)
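Every example in this collection guards the update with util.clip_grad_norm(optimizer, max_grad_norm) before optimizer.step(). The helper itself is not part of these snippets; a minimal sketch of what it presumably does, assuming it simply applies PyTorch's built-in clipping to every param group:

import torch.nn.utils

def clip_grad_norm(optimizer, max_norm, norm_type=2):
    """Clip the gradient norm of all parameters held by the optimizer (sketch)."""
    for group in optimizer.param_groups:
        torch.nn.utils.clip_grad_norm_(group['params'], max_norm, norm_type)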
Example #2
def train(epoch, net, trainloader, device, optimizer, scheduler, loss_fn, max_grad_norm):
    global global_step
    print('\nEpoch: %d' % epoch)
    net.train()
    loss_meter = util.AverageMeter()
    with tqdm(total=len(trainloader.dataset)) as progress_bar:
        for x, cond_x in trainloader:
            x, cond_x = x.to(device), cond_x.to(device)
            optimizer.zero_grad()
            z, sldj = net(x, cond_x, reverse=False)
            loss = loss_fn(z, sldj)
            loss_meter.update(loss.item(), x.size(0))
            loss.backward()
            if max_grad_norm > 0:
                util.clip_grad_norm(optimizer, max_grad_norm)
            optimizer.step()
            scheduler.step(global_step)

            progress_bar.set_postfix(nll=loss_meter.avg,
                                     bpd=util.bits_per_dim(x, loss_meter.avg),
                                     lr=optimizer.param_groups[0]['lr'])
            progress_bar.update(x.size(0))
            global_step += x.size(0)
    
    print('Saving...')
    state = {
        'net': net.state_dict(),
        'optimizer': optimizer.state_dict(),
        'epoch': epoch,
    }
    torch.save(state, 'savemodel/cINN/checkpoint_' + str(epoch) + '.tar')
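Several of the progress bars report bpd=util.bits_per_dim(x, loss_meter.avg). That helper is not shown here; a plausible sketch, assuming the standard conversion of a per-example NLL in nats into bits per input dimension:

import numpy as np

def bits_per_dim(x, nll):
    """Convert an average NLL (in nats) into bits per dimension of x (sketch)."""
    dim = np.prod(x.size()[1:])  # dimensions per example, e.g. C*H*W
    return nll / (np.log(2) * dim)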
Example #3
def plugin_estimator_training_loop(real_nvp_model, dataloader, learning_rate,
                                   optim, device, total_iters,
                                   checkpoint_intervals, batchsize, algorithm,
                                   weight_decay, max_grad_norm, save_dir,
                                   save_suffix):
    """
    Function to train the RealNVP model using
    a plugin mean estimation algorithm for total_iters with learning_rate
    """
    param_groups = util.get_param_groups(real_nvp_model,
                                         weight_decay,
                                         norm_suffix='weight_g')
    optimizer_cons = utils.get_optimizer_cons(optim, learning_rate)
    optimizer = optimizer_cons(param_groups)

    loss_fn = RealNVPLoss()
    flag = False
    iteration = 0

    real_nvp_model.train()
    while not flag:
        for x, _ in dataloader:
            # Update iteration counter
            iteration += 1

            x = x.to(device)
            # Clear any gradients accumulated in the previous iteration
            optimizer.zero_grad()
            z, sldj = real_nvp_model(x, reverse=False)
            unaggregated_loss = loss_fn(z, sldj, aggregate=False)
            if algorithm.__name__ == 'mean':
                agg_loss = unaggregated_loss.mean()
                agg_loss.backward()
            else:
                # First sample gradients
                sgradients = utils.gradient_sampler(unaggregated_loss,
                                                    real_nvp_model)
                # Then get the estimate with the mean estimation algorithm
                stoc_grad = algorithm(sgradients)
                # Perform the update of .grad attributes
                with torch.no_grad():
                    utils.update_grad_attributes(
                        real_nvp_model.parameters(),
                        torch.as_tensor(stoc_grad, device=device))
            # Clip gradient if required
            if max_grad_norm > 0:
                util.clip_grad_norm(optimizer, max_grad_norm)
            # Perform the update
            optimizer.step()

            if iteration in checkpoint_intervals:
                print(f"Completed {iteration}")
                torch.save(
                    real_nvp_model.state_dict(),
                    f"{save_dir}/real_nvp_{algorithm.__name__}_{iteration}_{save_suffix}.pt"
                )

            if iteration == total_iters:
                flag = True
                break

    return real_nvp_model
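utils.gradient_sampler and utils.update_grad_attributes come from an external module and are only called here. As a rough sketch (an assumption about the interface, not the actual implementation), update_grad_attributes could write a flat gradient estimate back into each parameter's .grad before optimizer.step():

import torch

def update_grad_attributes(parameters, flat_grad):
    """Copy consecutive slices of a flat gradient vector into each parameter's .grad (sketch)."""
    offset = 0
    for p in parameters:
        n = p.numel()
        p.grad = flat_grad[offset:offset + n].view_as(p).clone()
        offset += n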
Example #4
def train_single_step(net, x, device, optimizer, loss_fn, max_grad_norm):
    net.train()
    x = x.to(device)
    optimizer.zero_grad()
    z, sldj = net(x, reverse=False)
    loss = loss_fn(z, sldj)
    loss.backward()
    if max_grad_norm > 0:
        util.clip_grad_norm(optimizer, max_grad_norm)
    optimizer.step()
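A hypothetical call site for train_single_step, assuming a DataLoader and the same helpers (RealNVPLoss, util) used in the surrounding examples:

# Hypothetical driver loop; net, trainloader, device and the hyperparameters
# are assumptions following the other examples.
loss_fn = RealNVPLoss()
optimizer = torch.optim.Adam(net.parameters(), lr=1e-3)
for x, _ in trainloader:
    train_single_step(net, x, device, optimizer, loss_fn, max_grad_norm=100.)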
Example #5
def train(epoch, net, trainloader, device, optimizer, loss_fn, max_grad_norm):
    print('\nEpoch: %d' % epoch)
    net.train()
    loss_meter = util.AverageMeter()
    with tqdm(total=len(trainloader.dataset)) as progress_bar:
        for x, _ in trainloader:
            x = x.to(device)
            optimizer.zero_grad()
            z, sldj = net(x, reverse=False)
            loss = loss_fn(z, sldj)
            loss_meter.update(loss.item(), x.size(0))
            loss.backward()
            util.clip_grad_norm(optimizer, max_grad_norm)
            optimizer.step()

            progress_bar.set_postfix(loss=loss_meter.avg,
                                     bpd=util.bits_per_dim(x, loss_meter.avg))
            progress_bar.update(x.size(0))
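In the flow examples, loss_fn(z, sldj) is a RealNVP negative log-likelihood. A minimal sketch of such a loss, assuming a standard-normal prior over z and 8-bit inputs (the actual RealNVPLoss/NLLLoss classes are defined elsewhere):

import numpy as np
import torch.nn as nn

class RealNVPLoss(nn.Module):
    """NLL of a normalizing flow: prior log-prob of z plus the sum of log-det Jacobians (sketch)."""

    def __init__(self, k=256):
        super().__init__()
        self.k = k  # discrete levels per dimension for 8-bit images

    def forward(self, z, sldj):
        z = z.reshape(z.size(0), -1)
        prior_ll = -0.5 * (z ** 2 + np.log(2 * np.pi)).sum(dim=1)
        prior_ll = prior_ll - np.log(self.k) * z.size(1)  # dequantization correction
        return -(prior_ll + sldj).mean()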
Example #6
    def train_iter(self):
        """Run a training iteration (forward/backward) on a single batch.
        Important: Call `set_inputs` prior to each call to this function.
        """
        # Forward
        self.forward()

        # Backprop the generators
        self.opt_g.zero_grad()
        self.backward_g()
        util.clip_grad_norm(self.opt_g, self.max_grad_norm)
        self.opt_g.step()

        # Backprop the discriminators
        self.opt_d.zero_grad()
        self.backward_d()
        util.clip_grad_norm(self.opt_d, self.max_grad_norm)
        self.opt_d.step()
Example #7
def train(epoch, net, trainloader, device, optimizer, scheduler, loss_fn,
          max_grad_norm):
    global global_step
    print('\nEpoch: %d' % epoch)
    net.train()
    loss_meter = util.AverageMeter()
    correct_class = 0
    correct_domain = 0

    with tqdm(total=len(trainloader.dataset)) as progress_bar:
        for x, y, d, yd in trainloader:
            x, y, d, yd = x.to(device), y.to(device), d.to(device), yd.to(
                device)
            optimizer.zero_grad()
            z1, z2 = net(x)
            loss2 = loss_fn(z2, d.argmax(dim=1))
            loss1 = loss_fn(z1, y.argmax(dim=1))
            loss_meter.update(loss2.item(), x.size(0))
            loss_meter.update(loss1.item(), x.size(0))
            loss = loss1 + loss2
            loss.backward()
            if max_grad_norm > 0:
                util.clip_grad_norm(optimizer, max_grad_norm)
            optimizer.step()
            scheduler.step(global_step)

            progress_bar.set_postfix(nll=loss_meter.avg,
                                     bpd=util.bits_per_dim(x, loss_meter.avg),
                                     lr=optimizer.param_groups[0]['lr'])
            progress_bar.update(x.size(0))
            global_step += x.size(0)
            values_class, pred_class = torch.max(z1, 1)
            values_domain, pred_domain = torch.max(z2, 1)

            correct_class += pred_class.eq(y.argmax(dim=1)).sum().item()
            correct_domain += pred_domain.eq(d.argmax(dim=1)).sum().item()
    accuracy_class = correct_class * 100. / len(trainloader.dataset)
    accuracy_domain = correct_domain * 100. / len(trainloader.dataset)

    print('train accuracy class', accuracy_class)
    print('train accuracy domain', accuracy_domain)

    return accuracy_class, accuracy_domain
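util.AverageMeter keeps a running average weighted by the number of samples in each update. A minimal sketch, assuming the conventional implementation:

class AverageMeter:
    """Running, sample-weighted average of a scalar (sketch)."""

    def __init__(self):
        self.sum = 0.0
        self.count = 0
        self.avg = 0.0

    def update(self, val, n=1):
        self.sum += val * n
        self.count += n
        self.avg = self.sum / self.count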
Example #8
def train(epoch,
          net,
          trainloader,
          device,
          optimizer,
          loss_fn,
          max_grad_norm,
          base_path,
          save=False):
    print('\nEpoch: %d' % epoch)
    net.train()
    loss_meters = [util.AverageMeter() for _ in range(3)]
    logvars = []
    output_vars = []
    with tqdm(total=len(trainloader.dataset)) as progress_bar:
        for x in trainloader:
            if isinstance(x, (list, tuple)) and len(x) == 2:
                x = x[0]  # keep only the inputs when the loader also yields labels
            x = x.to(device)
            optimizer.zero_grad()
            x_hat, mu, logvar, output_var = net(x)
            loss, reconstruction_loss, kl_loss = loss_fn(
                x, x_hat, mu, logvar, output_var)
            loss_meters[0].update(loss.item(), x.size(0))
            loss_meters[1].update(reconstruction_loss.item(), x.size(0))
            loss_meters[2].update(kl_loss.item(), x.size(0))
            loss.backward()
            util.clip_grad_norm(optimizer, max_grad_norm)
            optimizer.step()

            progress_bar.set_postfix(loss=loss_meters[0].avg,
                                     rc_loss=loss_meters[1].avg,
                                     kl_loss=loss_meters[2].avg)
            progress_bar.update(x.size(0))
            # Detach before storing so the computation graphs are not kept alive
            logvars.append(logvar.detach().unsqueeze(0))
            output_vars.append(output_var.detach().unsqueeze(0))
    if save:
        logvarfile = 'logvar' + str(epoch) + '.pt'
        output_varfile = 'output_var' + str(epoch) + '.pt'
        torch.save(logvars, base_path / logvarfile)
        torch.save(output_vars, base_path / output_varfile)
Example #9
def train(epoch, net, trainloader, device, optimizer, scheduler, loss_fn,
          max_grad_norm, mode):
    global global_step
    print('\nEpoch: %d' % epoch)
    net.train()
    loss_meter = util.AverageMeter()
    correct = 0

    with tqdm(total=len(trainloader.dataset)) as progress_bar:
        for x, y, d, yd in trainloader:
            x, y, d, yd = x.to(device), y.to(device), d.to(device), yd.to(
                device)
            optimizer.zero_grad()
            z = net(x)
            if mode == 'domain':
                loss = loss_fn(z, d.argmax(dim=1))
            else:
                loss = loss_fn(z, y.argmax(dim=1))
            loss_meter.update(loss.item(), x.size(0))
            loss.backward()
            if max_grad_norm > 0:
                util.clip_grad_norm(optimizer, max_grad_norm)
            optimizer.step()
            scheduler.step(global_step)

            progress_bar.set_postfix(nll=loss_meter.avg,
                                     bpd=util.bits_per_dim(x, loss_meter.avg),
                                     lr=optimizer.param_groups[0]['lr'])
            progress_bar.update(x.size(0))
            global_step += x.size(0)
            values, pred = torch.max(z, 1)

            if mode == 'label':
                correct += pred.eq(y.argmax(dim=1)).sum().item()
            else:
                correct += pred.eq(d.argmax(dim=1)).sum().item()
    accuracy = correct * 100. / len(trainloader.dataset)

    print('train accuracy', accuracy)

    return accuracy
Example #10
def streaming_approx_training_loop(real_nvp_model, dataloader, learning_rate,
                                   optim, device, total_iters,
                                   checkpoint_intervals, alpha, batchsize,
                                   n_discard, weight_decay, max_grad_norm,
                                   save_dir, save_suffix):
    """
    Function to train the RealNVP model using
    the streaming rank-1 approximation with algorithm for total_iters
    with optimizer optim
    """
    param_groups = util.get_param_groups(real_nvp_model,
                                         weight_decay,
                                         norm_suffix='weight_g')
    optimizer_cons = utils.get_optimizer_cons(optim, learning_rate)
    optimizer = optimizer_cons(param_groups)

    loss_fn = RealNVPLoss()
    flag = False
    iteration = 0
    top_eigvec, top_eigval, running_mean = None, None, None

    real_nvp_model.train()
    while not flag:
        for x, _ in dataloader:
            # Update iteration counter
            iteration += 1

            x = x.to(device)
            z, sldj = real_nvp_model(x, reverse=False)
            unaggregated_loss = loss_fn(z, sldj, aggregate=False)
            # First sample gradients
            sgradients = utils.gradient_sampler(unaggregated_loss,
                                                real_nvp_model)
            # Then get the estimate with the previously computed direction
            stoc_grad, top_eigvec, top_eigval, running_mean = streaming_update_algorithm(
                sgradients,
                n_discard=n_discard,
                top_v=top_eigvec,
                top_lambda=top_eigval,
                old_mean=running_mean,
                alpha=alpha)
            # Perform the update of .grad attributes
            with torch.no_grad():
                utils.update_grad_attributes(
                    real_nvp_model.parameters(),
                    torch.as_tensor(stoc_grad, device=device))
            # Clip gradient if required
            if max_grad_norm > 0:
                util.clip_grad_norm(optimizer, max_grad_norm)
            # Perform the update
            optimizer.step()

            if iteration in checkpoint_intervals:
                print(f"Completed {iteration}")
                torch.save(
                    real_nvp_model.state_dict(),
                    f"{save_dir}/real_nvp_streaming_approx_{iteration}_{save_suffix}.pt"
                )

            if iteration == total_iters:
                flag = True
                break

    return real_nvp_model
Example #11
def train(model,
          embedder,
          optimizer,
          scheduler,
          train_loader,
          val_loader,
          opt,
          writer,
          device=None):
    print("TRAINING STARTS")
    global global_step
    for epoch in range(opt.n_epochs):
        print("[Epoch %d/%d]" % (epoch + 1, opt.n_epochs))
        model = model.train()
        loss_to_log = 0.0
        loss_fn = util.NLLLoss().to(device)
        with tqdm(total=len(train_loader.dataset)) as progress_bar:
            for i, (imgs, labels, captions) in enumerate(train_loader):
                start_batch = time.time()
                imgs = imgs.to(device)
                labels = labels.to(device)

                with torch.no_grad():
                    if opt.conditioning == 'unconditional':
                        condition_embd = None
                    else:
                        condition_embd = embedder(labels, captions)

                optimizer.zero_grad()

                # outputs = model.forward(imgs, condition_embd)
                # loss = outputs['loss'].mean()
                # loss.backward()
                # optimizer.step()
                z, sldj = model.forward(imgs, condition_embd, reverse=False)
                loss = loss_fn(z, sldj) / np.prod(imgs.size()[1:])
                loss.backward()
                if opt.max_grad_norm > 0:
                    util.clip_grad_norm(optimizer, opt.max_grad_norm)
                optimizer.step()
                scheduler.step(global_step)

                batches_done = epoch * len(train_loader) + i
                writer.add_scalar('train/bpd', loss / np.log(2), batches_done)
                loss_to_log += loss.item()
                # if (i + 1) % opt.print_every == 0:
                #     loss_to_log = loss_to_log / (np.log(2) * opt.print_every)
                #     print(
                #         "[Epoch %d/%d] [Batch %d/%d] [bpd: %f] [Time/batch %.3f]"
                #         % (epoch + 1, opt.n_epochs, i + 1, len(train_loader), loss_to_log, time.time() - start_batch)
                #     )
                progress_bar.set_postfix(bpd=(loss_to_log / np.log(2)),
                                         lr=optimizer.param_groups[0]['lr'])
                progress_bar.update(imgs.size(0))
                global_step += imgs.size(0)

                loss_to_log = 0.0

                if (batches_done + 1) % opt.sample_interval == 0:
                    print("sampling_images")
                    model = model.eval()
                    sample_image(model,
                                 embedder,
                                 opt.output_dir,
                                 n_row=4,
                                 batches_done=batches_done,
                                 dataloader=val_loader,
                                 device=device)

        val_bpd = eval(model, embedder, val_loader, opt, writer, device=device)
        writer.add_scalar("val/bpd", val_bpd, (epoch + 1) * len(train_loader))

        torch.save(
            model.state_dict(),
            os.path.join(opt.output_dir, 'models',
                         'epoch_{}.pt'.format(epoch)))
Example #12
def train(epochs, net, trainloader, device, optimizer, scheduler, loss_fn,
          max_grad_norm):
    global global_step

    net.train()
    loss_meter = util.AverageMeter()
    evaluator = EvaluationModel()
    test_conditions = get_test_conditions(os.path.join('test.json')).to(device)
    new_test_conditions = get_test_conditions(
        os.path.join('new_test.json')).to(device)
    best_score = 0
    new_best_score = 0

    for epoch in range(1, epochs + 1):
        print('\nEpoch: ', epoch)
        net.train()  # re-enable train mode (net.eval() is called after each epoch)
        with tqdm(total=len(trainloader.dataset)) as progress_bar:
            for x, cond_x in trainloader:
                x, cond_x = x.to(device, dtype=torch.float), cond_x.to(
                    device, dtype=torch.float)
                optimizer.zero_grad()
                z, sldj = net(x, cond_x, reverse=False)
                loss = loss_fn(z, sldj)
                wandb.log({'loss': loss})
                # print('loss: ',loss)
                loss_meter.update(loss.item(), x.size(0))
                # wandb.log({'loss_meter',loss_meter})
                loss.backward()
                if max_grad_norm > 0:
                    util.clip_grad_norm(optimizer, max_grad_norm)
                optimizer.step()
                # scheduler.step(global_step)

                progress_bar.set_postfix(nll=loss_meter.avg,
                                         bpd=util.bits_per_dim(
                                             x, loss_meter.avg),
                                         lr=optimizer.param_groups[0]['lr'])
                progress_bar.update(x.size(0))
                global_step += x.size(0)

        net.eval()
        with torch.no_grad():
            gen_imgs = sample(net, test_conditions, device)
        score = evaluator.eval(gen_imgs, test_conditions)
        wandb.log({'score': score})
        if score > best_score:
            best_score = score
            best_model_wts = copy.deepcopy(net.state_dict())
            torch.save(
                best_model_wts,
                os.path.join('weightings/test',
                             f'epoch{epoch}_score{score:.2f}.pt'))

        with torch.no_grad():
            new_gen_imgs = sample(net, new_test_conditions, device)
        new_score = evaluator.eval(new_gen_imgs, new_test_conditions)
        wandb.log({'new_score': new_score})
        if new_score > new_best_score:
            new_best_score = new_score
            new_best_model_wts = copy.deepcopy(net.state_dict())
            torch.save(
                new_best_model_wts,
                os.path.join('weightings/new_test',
                             f'epoch{epoch}_score{new_score:.2f}.pt'))
        save_image(gen_imgs,
                   os.path.join('results/test', f'epoch{epoch}.png'),
                   nrow=8,
                   normalize=True)
        save_image(new_gen_imgs,
                   os.path.join('results/new_test', f'epoch{epoch}.png'),
                   nrow=8,
                   normalize=True)
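Most of these loops persist net.state_dict() (and sometimes the optimizer state) to disk. A hedged example of restoring such a checkpoint, assuming the dictionary layout written by Example #2:

# Hypothetical resume snippet; the path and keys follow Example #2's checkpoint format.
checkpoint = torch.load('savemodel/cINN/checkpoint_10.tar', map_location=device)
net.load_state_dict(checkpoint['net'])
optimizer.load_state_dict(checkpoint['optimizer'])
start_epoch = checkpoint['epoch'] + 1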