Example #1
# prepare the training and validation data loaders
train_data, valid_data = prepare_dataset(root_path='../input/catsNdogs/')
trainset = LFWDataset(train_data, transform=transform)
trainloader = DataLoader(trainset, batch_size=batch_size, shuffle=True)
validset = LFWDataset(valid_data, transform=transform)
validloader = DataLoader(validset, batch_size=batch_size)

train_loss = []
valid_loss = []
grid_images = []  # one reconstruction grid per epoch, later written to a .gif
for epoch in range(epochs):
    print(f"Epoch {epoch+1} of {epochs}")
    train_epoch_loss = train(model, trainloader, trainset, device, optimizer,
                             criterion)
    valid_epoch_loss, recon_images = validate(model, validloader, validset,
                                              device, criterion)
    train_loss.append(train_epoch_loss)
    valid_loss.append(valid_epoch_loss)
    # save the reconstructed images from the validation loop
    save_reconstructed_images(recon_images, epoch + 1)
    # convert the reconstructed images to PyTorch image grid format
    image_grid = make_grid(recon_images.detach().cpu())
    grid_images.append(image_grid)
    print(f"Train Loss: {train_epoch_loss:.4f}")
    print(f"Val Loss: {valid_epoch_loss:.4f}")

# save the reconstructions as a .gif file
image_to_vid(grid_images)
# save the loss plots to disk
save_loss_plot(train_loss, valid_loss)
print('TRAINING COMPLETE')
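
# --- Illustrative sketch (not part of the original example) ---
# Example #1 calls save_loss_plot(train_loss, valid_loss) without showing it.
# A minimal matplotlib-based sketch of what such a helper might look like;
# the output filename and styling here are assumptions.
import matplotlib.pyplot as plt

def save_loss_plot(train_loss, valid_loss, out_path='loss.png'):
    """Plot training and validation loss curves and write them to disk."""
    plt.figure(figsize=(10, 7))
    plt.plot(train_loss, color='orange', label='train loss')
    plt.plot(valid_loss, color='red', label='validation loss')
    plt.xlabel('Epochs')
    plt.ylabel('Loss')
    plt.legend()
    plt.savefig(out_path)
    plt.close()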
Example #2
def train_gan(config, dataloader, device):
    #initialize models
    gen = Generator(config).to(device)
    dis = Discriminator(config).to(device)
    gen.apply(utils.init_weights)
    dis.apply(utils.init_weights)

    #setup optimizers
    gen_optimizer = torch.optim.Adam(params=gen.parameters(),
                                     lr=config['lr'],
                                     betas=[config['beta1'], config['beta2']])
    dis_optimizer = torch.optim.Adam(params=dis.parameters(),
                                     lr=config['lr'],
                                     betas=[config['beta1'], config['beta2']])

    criterion = torch.nn.BCELoss()
    fixed_latent = torch.randn(16, config['len_z'], 1, 1, device=device)
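    # note: BCELoss expects the discriminator to output probabilities in
    # [0, 1] (i.e. a sigmoid final layer); fixed_latent is kept constant so
    # the sample grids saved during training are comparable across iterations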

    dis_loss = []
    gen_loss = []
    generated_imgs = []
    iteration = 0

    #load parameters
    if (config['load_params'] and os.path.isfile("./gen_params.pth.tar")):
        print("loading params...")
        gen.load_state_dict(
            torch.load("./gen_params.pth.tar",
                       map_location=torch.device(device)))
        dis.load_state_dict(
            torch.load("./dis_params.pth.tar",
                       map_location=torch.device(device)))
        gen_optimizer.load_state_dict(
            torch.load("./gen_optimizer_state.pth.tar",
                       map_location=torch.device(device)))
        dis_optimizer.load_state_dict(
            torch.load("./dis_optimizer_state.pth.tar",
                       map_location=torch.device(device)))
        generated_imgs = torch.load("gen_imgs_array.pt",
                                    map_location=torch.device(device))
        print("loaded params.")

    #training
    start_time = time.time()
    gen.train()
    dis.train()
    for epoch in range(config['epochs']):
        iterator = iter(dataloader)
        dataloader_flag = True
        while (dataloader_flag):
            for _ in range(config['discriminator_steps']):
                dis.zero_grad()
                gen.zero_grad()
                dis_optimizer.zero_grad()

                #sample mini-batch
                z = torch.randn(config['batch_size'],
                                config['len_z'],
                                1,
                                1,
                                device=device)

                #get images from dataloader via iterator
                try:
                    imgs, _ = next(iterator)
                    imgs = imgs.to(device)
                except StopIteration:
                    dataloader_flag = False
                    break

                #compute loss
                loss_true_imgs = criterion(
                    dis(imgs).view(-1), torch.ones(imgs.shape[0],
                                                   device=device))
                loss_true_imgs.backward()
                fake_images = gen(z)
                loss_fake_imgs = criterion(
                    dis(fake_images.detach()).view(-1),
                    torch.zeros(z.shape[0], device=device))
                loss_fake_imgs.backward()

                total_error = loss_fake_imgs + loss_true_imgs
                dis_optimizer.step()

            #generator step
            for _ in range(config['generator_steps']):
                if not dataloader_flag:
                    break
                gen.zero_grad()
                dis.zero_grad()
                dis_optimizer.zero_grad()
                gen_optimizer.zero_grad()

                #z = torch.randn(config['batch_size'],config['len_z'])   #sample mini-batch
                loss_gen = criterion(
                    dis(fake_images).view(-1),
                    torch.ones(z.shape[0], device=device))  #compute loss

                #update params
                loss_gen.backward()
                gen_optimizer.step()

            iteration += 1

            #log and save variable, losses and generated images
            if (iteration % 100) == 0:
                elapsed_time = time.time() - start_time
                dis_loss.append(total_error.mean().item())
                gen_loss.append(loss_gen.mean().item())

                with torch.no_grad():
                    generated_imgs.append(
                        gen(fixed_latent).detach())  #generate image
                    torch.save(generated_imgs, "gen_imgs_array.pt")

                print(
                    "Iteration:%d, Dis Loss:%.4f, Gen Loss:%.4f, time elapsed:%.4f"
                    % (iteration, dis_loss[-1], gen_loss[-1], elapsed_time))

                if (config['save_params'] and iteration % 400 == 0):
                    print("saving params...")
                    torch.save(gen.state_dict(), "./gen_params.pth.tar")
                    torch.save(dis.state_dict(), "./dis_params.pth.tar")
                    torch.save(dis_optimizer.state_dict(),
                               "./dis_optimizer_state.pth.tar")
                    torch.save(gen_optimizer.state_dict(),
                               "./gen_optimizer_state.pth.tar")
                    print("saved params.")

    #plot errors
    utils.save_loss_plot(gen_loss, dis_loss)

    #plot generated images
    utils.save_result_images(
        next(iter(dataloader))[0][:15].to(device), generated_imgs[-1], 4,
        config)

    #save generated images to see what happened
    torch.save(generated_imgs, "gen_imgs_array.pt")

    #save gif
    utils.save_gif(generated_imgs, 4, config)
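
# --- Hypothetical usage sketch (not part of the original example) ---
# The config keys below are only those train_gan itself reads; Generator,
# Discriminator and the utils helpers may expect additional entries.
# FakeData stands in for a real image dataset.
if __name__ == '__main__':
    import torch
    import torchvision
    import torchvision.transforms as T
    from torch.utils.data import DataLoader

    config = {
        'lr': 2e-4, 'beta1': 0.5, 'beta2': 0.999, 'len_z': 100,
        'batch_size': 128, 'epochs': 5, 'discriminator_steps': 1,
        'generator_steps': 1, 'load_params': False, 'save_params': True,
    }
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    dataset = torchvision.datasets.FakeData(size=1024,
                                            image_size=(3, 64, 64),
                                            transform=T.ToTensor())
    dataloader = DataLoader(dataset,
                            batch_size=config['batch_size'],
                            shuffle=True,
                            drop_last=True)
    train_gan(config, dataloader, device)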
Example #3
def train_model(model, criterion, optimizer, scheduler, cfg):

    best_loss = float('inf')  # lower is better, so start from +inf
    best_model_wts = copy.deepcopy(model.state_dict())
    train_losses = []
    val_losses = []
    unit_diamond_vertices = get_unit_diamond_vertices(root_path)
    for epoch in range(1, cfg['max_epoch'] + 1):

        print('-' * 60)
        print('Epoch: {} / {}'.format(epoch, cfg['max_epoch']))
        print('-' * 60)
        for phrase in ['train', 'val']:

            if phrase == 'train':
                scheduler.step()
                model.train()
            else:
                model.eval()

            running_loss = 0.0
            running_corrects = 0
            ft_all, lbl_all = None, None

            for i, (centers, corners, normals, neighbor_index, targets,
                    impurity_label) in enumerate(data_loader[phrase]):

                optimizer.zero_grad()
                if use_gpu:
                    centers = Variable(torch.cuda.FloatTensor(centers.cuda()))
                    corners = Variable(torch.cuda.FloatTensor(corners.cuda()))
                    normals = Variable(torch.cuda.FloatTensor(normals.cuda()))
                    neighbor_index = Variable(
                        torch.cuda.LongTensor(neighbor_index.cuda()))
                    targets = Variable(torch.cuda.FloatTensor(targets.cuda()))
                    impurity_label = Variable(
                        torch.cuda.FloatTensor(impurity_label.cuda()))
                    unit_diamond_vertices = Variable(
                        torch.cuda.FloatTensor(unit_diamond_vertices.cuda()))
                else:
                    centers = Variable(torch.FloatTensor(centers))
                    corners = Variable(torch.FloatTensor(corners))
                    normals = Variable(torch.FloatTensor(normals))
                    neighbor_index = Variable(torch.LongTensor(neighbor_index))
                    targets = Variable(torch.FloatTensor(targets))
                    impurity_label = Variable(
                        torch.FloatTensor(impurity_label))
                    unit_diamond_vertices = Variable(
                        torch.FloatTensor(unit_diamond_vertices))

                with torch.set_grad_enabled(phrase == 'train'):
                    eps = 1e-12
                    outputs, feas = model(centers, corners, normals,
                                          neighbor_index, impurity_label)
                    #loss = criterion(outputs, targets)
                    #loss = stochastic_loss(criterion, outputs, targets)
                    loss = point_wise_L1_loss(outputs, targets,
                                              unit_diamond_vertices)
                    if phrase == 'train':
                        loss.backward()
                        optimizer.step()

                    running_loss += loss.item() * centers.size(0)

            epoch_loss = running_loss / len(data_set[phrase])

            if phrase == 'train':
                print('{} Loss: {:.4f}'.format(phrase, epoch_loss))
                train_losses.append(epoch_loss)

            if phrase == 'val':
                val_losses.append(epoch_loss)
                if epoch_loss < best_loss:
                    best_loss = epoch_loss
                    best_model_wts = copy.deepcopy(model.state_dict())
                if epoch % 2 == 0:
                    torch.save(copy.deepcopy(model.state_dict()),
                               root_path + '/ckpt_root/{}.pkl'.format(epoch))

                print('{} Loss: {:.4f}'.format(phrase, epoch_loss))

        save_loss_plot(train_losses, val_losses, root_path)

    return best_model_wts
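
# --- Self-contained illustration (not part of the original example) ---
# The loop above reuses the same forward code for both phases by pairing
# model.train()/model.eval() with torch.set_grad_enabled(phrase == 'train'),
# so gradients are only tracked during the training phase. A toy demo:
import torch
import torch.nn as nn

toy_model = nn.Linear(4, 2)
x = torch.randn(8, 4)
for phase in ['train', 'val']:
    toy_model.train() if phase == 'train' else toy_model.eval()
    with torch.set_grad_enabled(phase == 'train'):
        out = toy_model(x)
        print(phase, out.requires_grad)  # True for 'train', False for 'val'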
Example #4
def train_model(model, optimizer, scheduler, cfg):

    best_loss = float('inf')  # lower is better, so start from +inf
    best_model_wts = copy.deepcopy(model.state_dict())
    train_losses = []
    val_losses = []
    unit_diamond_vertices = get_unit_diamond_vertices(root_path)
    if use_gpu:
        unit_diamond_vertices = Variable(
            torch.cuda.FloatTensor(unit_diamond_vertices.cuda()))
    else:
        unit_diamond_vertices = Variable(
            torch.FloatTensor(unit_diamond_vertices))

    for epoch in range(1, cfg['max_epoch'] + 1):

        print('-' * 60)
        print('Epoch: {} / {}'.format(epoch, cfg['max_epoch']))
        print('-' * 60)
        for phrase in ['train', 'val']:

            if phrase == 'train':
                scheduler.step()
                model.train()
            else:
                model.eval()

            running_loss = 0.0
            running_corrects = 0
            ft_all, lbl_all = None, None

            for i, (input, diamond_center_grid_point, targets, pitch,
                    radius) in enumerate(data_loader[phrase]):

                optimizer.zero_grad()
                if use_gpu:
                    input = Variable(torch.cuda.FloatTensor(input.cuda()))
                    diamond_center_grid_point = Variable(
                        torch.cuda.LongTensor(
                            diamond_center_grid_point.cuda()))
                    targets = Variable(torch.cuda.FloatTensor(targets.cuda()))
                else:
                    input = Variable(torch.FloatTensor(input))
                    diamond_center_grid_point = Variable(
                        torch.LongTensor(diamond_center_grid_point))
                    targets = Variable(torch.FloatTensor(targets))
                    
                with torch.set_grad_enabled(phrase == 'train'):
                    eps = 1e-12
                    #center_probs, pred_rot_scale = model(input, return_encoder_features=True)
                    center_probs = model(input, return_encoder_features=False)
                    #loss = regression_classification_loss(center_probs, pred_rot_scale, diamond_center_grid_point, targets[:, 3:], alpha=0.5)
                    loss = unet_loss(center_probs, diamond_center_grid_point,
                                     targets[:, 3:], alpha=0.5)
                    #loss = yolo_loss(center_probs, targets, diamond_center_grid_point, unit_diamond_vertices, alpha=1)
                    if phrase == 'train':
                        loss.backward()
                        optimizer.step()

                    running_loss += loss.item() * input.size(0)

            epoch_loss = running_loss / len(data_set[phrase])

            if phrase == 'train':
                print('{} Loss: {:.4f}'.format(phrase, epoch_loss))
                train_losses.append(epoch_loss)

            if phrase == 'val':
                val_losses.append(epoch_loss)
                if epoch_loss < best_loss:
                    best_loss = epoch_loss
                    best_model_wts = copy.deepcopy(model.state_dict())
                if epoch % 1 == 0:
                    torch.save(copy.deepcopy(model.state_dict()),
                               root_path + '/ckpt_root/{}.pkl'.format(epoch))

                print('{} Loss: {:.4f}'.format(phrase, epoch_loss))
        
        save_loss_plot(train_losses, val_losses, root_path)

    return best_model_wts
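
# --- Hypothetical checkpoint round-trip (not part of the original example) ---
# The loop above writes model.state_dict() to ckpt_root every epoch; such a
# checkpoint would typically be restored as shown below (the toy model and
# filename are placeholders, not the real network or ckpt_root path):
import torch
import torch.nn as nn

net = nn.Linear(3, 1)
torch.save(net.state_dict(), './1.pkl')
net.load_state_dict(torch.load('./1.pkl', map_location='cpu'))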
Example #5
                # Append losses
                loss_history.append(losses)

            if iteration % 100 == 0:
                # Run inference on test images
                print("Computing visuals...", end=' ')
                model.compute_visuals(test_data, IM_SIZE, epoch, iteration)
                print("done")

                # Save model weights
                print("Saving model...", end=' ')
                model.save_weights(epoch, iteration)
                print("done")

        # Update learning rate at end of epoch
        print("Updating learning rate...", end=' ')
        model.update_lr()
        print("done")
        for G_pg, D_pg in zip(model.optimizer_G.param_groups,
                              model.optimizer_D.param_groups):
            print(f"New learning rate: G: {G_pg['lr']}, D: {D_pg['lr']}")

        # Print epoch duration
        epoch_duration = time.time() - epoch_start
        print(f"Epoch completed in {epoch_duration:.0f}s")

        # Save loss plot to image
        print("Saving loss plot...", end=' ')
        utils.save_loss_plot(loss_history, plot_title)
        print("done")
Example #6
def train(model):
    # Checks what kind of model it is training
    if model.model_type == "infoGAN":
        is_infogan = True
    else:
        is_infogan = False

    # Makes sure we have a dir to save the model and training info
    if not os.path.exists(model.save_dir):
        os.makedirs(model.save_dir)

    # Creates artificial labels that just indicates to the loss object if prediction of D should be 0 or 1
    if model.gpu_mode:
        y_real_ = Variable(torch.ones(model.batch_size,
                                      1).cuda(model.gpu_id))  # all ones
        y_fake_ = Variable(
            torch.zeros(model.batch_size, 1).cuda(model.gpu_id))  # all zeros
    else:
        y_real_ = Variable(torch.ones(model.batch_size, 1))
        y_fake_ = Variable(torch.zeros(model.batch_size, 1))

    model.D.train()  # sets discriminator in train mode

    # TRAINING LOOP
    start_time = time.time()
    print('[*] TRAINING STARTS')
    for epoch in range(model.epoch):
        model.G.train()  # sets generator in train mode
        epoch_start_time = time.time()

        # For each minibatch returned by the data_loader
        for step, (x_, _) in enumerate(model.data_loader):
            if step == len(model.data_loader.dataset) // model.batch_size:
                break

            # Creates a minibatch of latent vectors
            z_ = torch.rand((model.batch_size, model.z_dim))

            # Creates a minibatch of discrete and continuous codes
            c_disc_ = torch.from_numpy(
                np.random.multinomial(
                    1,
                    model.c_disc_dim * [float(1.0 / model.c_disc_dim)],
                    size=[model.batch_size])).type(torch.FloatTensor)
            for i in range(model.n_disc_code - 1):
                c_disc_ = torch.cat([
                    c_disc_,
                    torch.from_numpy(
                        np.random.multinomial(
                            1,
                            model.c_disc_dim * [float(1.0 / model.c_disc_dim)],
                            size=[model.batch_size])).type(torch.FloatTensor)
                ],
                                    dim=1)
            c_cont_ = torch.from_numpy(
                np.random.uniform(-1,
                                  1,
                                  size=(model.batch_size,
                                        model.c_cont_dim))).type(
                                            torch.FloatTensor)

            # Convert to Variables (sends to GPU if needed)
            if model.gpu_mode:
                x_ = Variable(x_.cuda(model.gpu_id))
                z_ = Variable(z_.cuda(model.gpu_id))
                c_disc_ = Variable(c_disc_.cuda(model.gpu_id))
                c_cont_ = Variable(c_cont_.cuda(model.gpu_id))
            else:
                x_ = Variable(x_)
                z_ = Variable(z_)
                c_disc_ = Variable(c_disc_)
                c_cont_ = Variable(c_cont_)

            # update D network
            model.D_optimizer.zero_grad()

            D_real, _, _ = model.D(x_, model.dataset)
            D_real_loss = model.BCE_loss(D_real, y_real_)

            G_ = model.G(z_, c_cont_, c_disc_, model.dataset)
            D_fake, _, _ = model.D(G_, model.dataset)
            D_fake_loss = model.BCE_loss(D_fake, y_fake_)

            D_loss = D_real_loss + D_fake_loss
            model.train_history['D_loss'].append(D_loss.item())

            D_loss.backward(retain_graph=is_infogan)
            model.D_optimizer.step()

            # update G network
            model.G_optimizer.zero_grad()

            G_ = model.G(z_, c_cont_, c_disc_, model.dataset)
            D_fake, D_cont, D_disc = model.D(G_, model.dataset)

            G_loss = model.BCE_loss(D_fake, y_real_)
            model.train_history['G_loss'].append(G_loss.item())

            G_loss.backward(retain_graph=is_infogan)
            model.G_optimizer.step()

            # information loss
            if is_infogan:
                disc_loss = 0
                for i, ce_loss in enumerate(model.CE_losses):
                    i0 = i * model.c_disc_dim
                    i1 = (i + 1) * model.c_disc_dim
                    disc_loss += ce_loss(D_disc[:, i0:i1],
                                         torch.max(c_disc_[:, i0:i1], 1)[1])
                cont_loss = model.MSE_loss(D_cont, c_cont_)
                info_loss = disc_loss + cont_loss
                model.train_history['info_loss'].append(info_loss.item())

                info_loss.backward()
                model.info_optimizer.step()

            # Prints training info every 100 steps
            if ((step + 1) % 100) == 0:
                if is_infogan:
                    print(
                        "Epoch: [{:2d}] [{:4d}/{:4d}] D_loss: {:.8f}, G_loss: {:.8f}, info_loss: {:.8f}"
                        .format((epoch + 1), (step + 1),
                                len(model.data_loader.dataset) //
                                model.batch_size, D_loss.item(),
                                G_loss.item(), info_loss.item()))
                else:
                    print(
                        "Epoch: [{:2d}] [{:4d}/{:4d}] D_loss: {:.8f}, G_loss: {:.8f}"
                        .format((epoch + 1), (step + 1),
                                len(model.data_loader.dataset) //
                                model.batch_size, D_loss.item(),
                                G_loss.item()))

        model.train_history['per_epoch_time'].append(time.time() -
                                                     epoch_start_time)

        # Saves samples
        utils.generate_samples(
            model, os.path.join(model.save_dir, "epoch{}.png".format(epoch)))

    model.train_history['total_time'].append(time.time() - start_time)
    print("Avg one epoch time: %.2f, total %d epochs time: %.2f" %
          (np.mean(model.train_history['per_epoch_time']), model.epoch,
           model.train_history['total_time'][0]))
    print("[*] TRAINING FINISHED")

    # Saves the model
    model.save()

    # Saves the plot of losses for G and D
    utils.save_loss_plot(model.train_history,
                         filename=os.path.join(model.save_dir, "curves.png"),
                         infogan=is_infogan)
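
# --- Self-contained illustration (not part of the original example) ---
# The discrete latent code above is sampled with np.random.multinomial, which
# yields one one-hot row per batch element over c_disc_dim categories. A toy
# demo of the resulting shape and one-hot property:
import numpy as np
import torch

batch_size, c_disc_dim = 4, 10
c_disc = torch.from_numpy(
    np.random.multinomial(1, [1.0 / c_disc_dim] * c_disc_dim,
                          size=[batch_size])).type(torch.FloatTensor)
print(c_disc.shape)       # torch.Size([4, 10])
print(c_disc.sum(dim=1))  # tensor([1., 1., 1., 1.]) -- one-hot rows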