Ejemplo n.º 1
0
def main():
    """Train the 2D-to-3D lifting network with two adversarial critics.

    Builds the lifter plus a pose discriminator and a temporal
    discriminator, trains on Human3.6M (2D detections from the Stacked
    Hourglass network), and logs losses / MPJPE to the checkpoint dir.
    """
    # Make sure the checkpoint directory exists before any logging starts.
    if not osp.isdir(args.checkpoint):
        mkdir_p(args.checkpoint)

    device = torch.device('cuda:0') if torch.cuda.is_available() \
        else torch.device('cpu')

    # Networks: lifter + pose discriminator + temporal discriminator.
    lifter = Lifter().to(device)
    disc_pose = Discriminator().to(device)
    disc_temp = Discriminator().to(device)

    opt_lifter = optim.Adam(lifter.parameters(), lr=args.lift_lr)
    opt_disc_pose = optim.Adam(disc_pose.parameters(), lr=args.disc_lr)
    opt_disc_temp = optim.Adam(disc_temp.parameters(), lr=args.disc_lr)

    # use 2D results from Stack Hourglass Net
    train_loader = data.DataLoader(
        H36M(length=args.length, action='all', is_train=True, use_sh_detection=True),
        batch_size=1024,
        shuffle=True,
        pin_memory=True,
    )
    test_loader = data.DataLoader(
        H36M(length=1, action='all', is_train=False, use_sh_detection=True),
        batch_size=512,
        shuffle=False,
    )

    # Loggers: one for the training losses, one for the evaluation error.
    logger = Logger(osp.join(args.checkpoint, 'log.txt'), title='Human3.6M')
    logger_err = Logger(osp.join(args.checkpoint, 'log_err.txt'), title='Human3.6M MPJPE err')
    logger.set_names(['2d_loss   ', '3d_loss   ', 'adv_loss   ', 'temporal_loss   '])
    logger_err.set_names(['err'])

    for epoch in range(args.epoches):
        print('\nEpoch: [%d / %d]' % (epoch + 1, args.epoches))

        loss_2d, loss_3d, loss_adv, loss_t = train(
            train_loader, lifter, disc_pose, disc_temp,
            opt_lifter, opt_disc_pose, opt_disc_temp,
            epoch + 1, device, args)
        logger.append([loss_2d, loss_3d, loss_adv, loss_t])

        # Periodically snapshot all three networks.
        if (epoch + 1) % args.checkpoint_save_interval == 0:
            save_checkpoint({
                'epoch': epoch + 1,
                'state_dict_L': lifter.state_dict(),
                'state_dict_D': disc_pose.state_dict(),
                'state_dict_T': disc_temp.state_dict(),
            }, checkpoint=args.checkpoint)

        # Periodically evaluate MPJPE on the held-out split.
        if (epoch + 1) % args.eval_interval == 0:
            ttl_err = test(test_loader, lifter, epoch, device, args)
            logger_err.append([ttl_err])

    logger.close()
    logger_err.close()
Ejemplo n.º 2
0
        train_log['acc@1'],
        train_log['acc@5'],
        val_log['loss'],
        val_log['acc@1'],
        val_log['acc@5'],
    ],
                    index=[
                        'epoch', 'lr', 'loss', 'acc@1', 'acc@5', 'val_loss',
                        'val_acc1', 'val_acc5'
                    ])

    log = log.append(tmp, ignore_index=True)

    if val_log['loss'] < best_loss:
        torch.save(model.state_dict(), 'vae_model_best.pt')
        torch.save(D.state_dict(), 'D_model_best.pt')
        torch.save(metric_fc.state_dict(), 'metric_best.pt')
        best_loss = val_log['loss']
        print("=> saved best model")

# Plot the generator (G), discriminator (D) and encoder (E) loss curves
# accumulated in Error_log over the whole training run.
plt.title(
    "Encoder, Decoder (Generator) and Discriminator Loss During Training")
plt.plot(Error_log['G_Losses'], label="G")
plt.plot(Error_log['D_Losses'], label="D")
plt.plot(Error_log['E_Losses'], label="E")
plt.xlabel("Iterations")
plt.ylabel("Loss")
plt.legend()
plt.show()

# Restore the weights that achieved the best validation loss
# (saved earlier as 'vae_model_best.pt') before further use of the model.
model.load_state_dict(torch.load('vae_model_best.pt'))
    os.makedirs(out_train_demo_path, exist_ok=True)
    os.makedirs(out_test_demo_path, exist_ok=True)

    # demo
    with torch.no_grad():
        # train data
        fake_train_demo = netG(G_train_demo_input).permute(2, 0, 1, 3, 4)
        save_video(fake_train_demo, opt.nframes, out_train_demo_path,
                   'generated')
        # test data
        fake_test_demo = netG(G_test_demo_input).permute(2, 0, 1, 3, 4)
        save_video(fake_test_demo, opt.nframes, out_test_demo_path,
                   'generated')

    if (epoch + 1) % 5 == 0:
        # save weight
        model_dir = os.path.join('./checkpoints', opt.checkpoint, 'weight')
        os.makedirs(model_dir, exist_ok=True)
        model_G_name = 'netG_epoch_' + str(epoch + 1) + '.pth'
        model_D_name = 'netD_epoch_' + str(epoch + 1) + '.pth'
        torch.save(netG.state_dict(), os.path.join(model_dir, model_G_name))
        torch.save(netD.state_dict(), os.path.join(model_dir, model_D_name))

    # epoch end
    now_time = time.time()
    elapsed_time = int(now_time - start_time)
    hour_minute_time = divmod(elapsed_time, 3600)
    print('---TIME---')
    print('%d hour %d minute' %
          (hour_minute_time[0], hour_minute_time[1] / 60))
Ejemplo n.º 4
0
            fid, kid = fids[l], kids[l]
            best_note = ''
            if min_fid > fid:
                min_fid = fid
                best_note = '    (best)'
                update_best_img = True
            if l < num_parallel:
                alpha = '    %.2f' % alpha_soft[l]
                img_type_str = '(%s)' % params.img_types[l][:10]
            else:
                alpha = '        '
                img_type_str = '(ens)'
            print_log('Epoch %3d %-15s   l1_avg_loss: %.5f   rl2_avg_loss: %.5f   fid: %.3f   kid: %.3f%s%s' % \
                (epoch, img_type_str, l1_avg_loss, rl2_avg_loss, fid, kid, alpha, best_note))
        print_log('')
        if update_best_img:
            os.system('cp -r %s/fake* %s' % (save_dir, save_dir_best))

    if (epoch + 1) % 100 == 0:
        torch.save(G.state_dict(),
                   os.path.join(model_dir, 'checkpoint-gen-%d.pkl' % epoch))
        torch.save(D.state_dict(),
                   os.path.join(model_dir, 'checkpoint-dis-%d.pkl' % epoch))

# Plot average losses
# (per-epoch averages of the discriminator and generator losses).
plot_loss(D_avg_losses, G_avg_losses, params.num_epochs, save_dir=save_dir)
# Save trained parameters of model
torch.save(G.state_dict(), os.path.join(model_dir, 'checkpoint-gen.pkl'))
torch.save(D.state_dict(), os.path.join(model_dir, 'checkpoint-dis.pkl'))
# Close the run logger so buffered output is flushed.
cf.logger.close()
Ejemplo n.º 5
0
class BEGAN(object):
    """Boundary Equilibrium GAN (BEGAN) trainer.

    Wires up the generator/discriminator pair, optimizers, LR schedulers,
    the data loader and (optionally) visdom visualization, and exposes
    ``train()`` plus sampling / interpolation / checkpoint utilities.
    """

    def __init__(self, args):
        # Misc
        self.args = args
        self.cuda = args.cuda and torch.cuda.is_available()
        self.sample_num = 100  # images per fixed/random sample grid

        # Optimization
        self.epoch = args.epoch
        self.batch_size = args.batch_size
        self.D_lr = args.D_lr
        self.G_lr = args.G_lr
        self.gamma = args.gamma        # BEGAN diversity ratio (balance target)
        self.lambda_k = args.lambda_k  # learning rate of the k_t control variable
        self.Kt = 0.0                  # equilibrium control variable k_t
        self.global_epoch = 0
        self.global_iter = 0

        # Visualization
        self.env_name = args.env_name
        self.visdom = args.visdom
        self.port = args.port
        self.timestep = args.timestep
        self.output_dir = Path(args.output_dir).joinpath(args.env_name)
        self.visualization_init()

        # Network
        self.model_type = args.model_type
        self.n_filter = args.n_filter
        self.n_repeat = args.n_repeat
        self.image_size = args.image_size
        self.hidden_dim = args.hidden_dim
        self.fixed_z = Variable(cuda(self.sample_z(self.sample_num),
                                     self.cuda))
        self.ckpt_dir = Path(args.ckpt_dir).joinpath(args.env_name)
        self.load_ckpt = args.load_ckpt
        self.model_init()

        # Dataset
        self.dataset = args.dataset
        self.data_loader = return_data(args)

        # Halve the LR 8 times over the full run. Clamp to >= 1 so the
        # modulo check in train() cannot divide by zero on tiny datasets.
        self.lr_step_size = max(
            1,
            len(self.data_loader['train'].dataset)
            // self.batch_size * self.epoch // 8)

    def model_init(self):
        """Build D and G, init weights, optimizers, schedulers; maybe resume."""
        self.D = Discriminator(self.model_type, self.image_size,
                               self.hidden_dim, self.n_filter, self.n_repeat)
        self.G = Generator(self.model_type, self.image_size, self.hidden_dim,
                           self.n_filter, self.n_repeat)

        self.D = cuda(self.D, self.cuda)
        self.G = cuda(self.G, self.cuda)

        self.D.weight_init(mean=0.0, std=0.02)
        self.G.weight_init(mean=0.0, std=0.02)

        self.D_optim = optim.Adam(self.D.parameters(),
                                  lr=self.D_lr,
                                  betas=(0.5, 0.999))
        self.G_optim = optim.Adam(self.G.parameters(),
                                  lr=self.G_lr,
                                  betas=(0.5, 0.999))

        # Step decay (gamma=0.5) driven manually via scheduler_step();
        # step_size=1 means every explicit call halves the LR.
        self.D_optim_scheduler = lr_scheduler.StepLR(self.D_optim,
                                                     step_size=1,
                                                     gamma=0.5)
        self.G_optim_scheduler = lr_scheduler.StepLR(self.G_optim,
                                                     step_size=1,
                                                     gamma=0.5)

        if not self.ckpt_dir.exists():
            self.ckpt_dir.mkdir(parents=True, exist_ok=True)

        if self.load_ckpt:
            self.load_checkpoint()

    def visualization_init(self):
        """Create the output dir and, when enabled, the visdom windows."""
        if not self.output_dir.exists():
            self.output_dir.mkdir(parents=True, exist_ok=True)

        # Handle of the measure-of-convergence curve. Initialised even when
        # visdom is off so save_checkpoint() can always read it safely.
        self.win_moc = None

        if self.visdom:
            self.viz_train_curves = visdom.Visdom(env=self.env_name +
                                                  '/train_curves',
                                                  port=self.port)
            self.viz_train_samples = visdom.Visdom(env=self.env_name +
                                                   '/train_samples',
                                                   port=self.port)
            self.viz_test_samples = visdom.Visdom(env=self.env_name +
                                                  '/test_samples',
                                                  port=self.port)
            self.viz_interpolations = visdom.Visdom(env=self.env_name +
                                                    '/interpolations',
                                                    port=self.port)

    def sample_z(self, batch_size=0, dim=0, dist='uniform'):
        """Draw latent noise of shape (batch_size, dim).

        batch_size / dim of 0 fall back to the instance defaults.
        'normal' -> N(0, 1); 'uniform' -> U(-1, 1); anything else -> None.
        """
        if batch_size == 0:
            batch_size = self.batch_size
        if dim == 0:
            dim = self.hidden_dim

        if dist == 'normal':
            return torch.randn(batch_size, dim)
        elif dist == 'uniform':
            return torch.rand(batch_size, dim).mul(2).add(-1)
        else:
            return None

    def sample_img(self, _type='fixed', nrow=10):
        """Generate a sample grid ('fixed' or 'random' z), save and return it."""
        self.set_mode('eval')

        if _type == 'fixed':
            z = self.fixed_z
        elif _type == 'random':
            z = self.sample_z(self.sample_num)
            z = Variable(cuda(z, self.cuda))
        else:
            self.set_mode('train')
            return

        samples = self.unscale(self.G(z))
        samples = samples.data.cpu()

        filename = self.output_dir.joinpath(_type + ':' +
                                            str(self.global_iter) + '.jpg')
        grid = make_grid(samples, nrow=nrow, padding=2, normalize=False)
        # Pass the path positionally: torchvision renamed the keyword
        # `filename` to `fp`, so the keyword form breaks on newer versions.
        save_image(grid, filename)
        if self.visdom:
            self.viz_test_samples.image(grid,
                                        opts=dict(title=str(filename),
                                                  nrow=nrow,
                                                  factor=2))

        self.set_mode('train')
        return grid

    def set_mode(self, mode='train'):
        """Switch both networks between 'train' and 'eval' mode."""
        if mode == 'train':
            self.G.train()
            self.D.train()
        elif mode == 'eval':
            self.G.eval()
            self.D.eval()
        else:
            # Original code raised a bare string, which is a TypeError in
            # Python 3; raise a proper exception instead.
            raise ValueError('mode error. It should be either train or eval')

    def scheduler_step(self):
        """Advance both LR schedulers one decay step (halves the LRs)."""
        self.D_optim_scheduler.step()
        self.G_optim_scheduler.step()

    def unscale(self, tensor):
        """Map tensors from [-1, 1] back to the displayable [0, 1] range."""
        return tensor.mul(0.5).add(0.5)

    def save_checkpoint(self, filename='ckpt.tar'):
        """Serialize models, optimizers and training counters to ckpt_dir."""
        model_states = {'G': self.G.state_dict(), 'D': self.D.state_dict()}
        optim_states = {
            'G_optim': self.G_optim.state_dict(),
            'D_optim': self.D_optim.state_dict()
        }
        states = {
            'iter': self.global_iter,
            'epoch': self.global_epoch,
            'args': self.args,
            'win_moc': self.win_moc,
            'fixed_z': self.fixed_z.data.cpu(),
            'model_states': model_states,
            'optim_states': optim_states
        }

        file_path = self.ckpt_dir.joinpath(filename)
        torch.save(states, file_path.open('wb+'))
        print("=> saved checkpoint '{}' (iter {})".format(
            file_path, self.global_iter))

    def load_checkpoint(self, filename='ckpt.tar'):
        """Restore a checkpoint written by save_checkpoint(), if present."""
        file_path = self.ckpt_dir.joinpath(filename)
        if file_path.is_file():
            checkpoint = torch.load(file_path.open('rb'))
            self.global_iter = checkpoint['iter']
            self.global_epoch = checkpoint['epoch']
            self.win_moc = checkpoint['win_moc']
            self.fixed_z = checkpoint['fixed_z']
            self.fixed_z = Variable(cuda(self.fixed_z, self.cuda))
            self.G.load_state_dict(checkpoint['model_states']['G'])
            self.D.load_state_dict(checkpoint['model_states']['D'])
            self.G_optim.load_state_dict(checkpoint['optim_states']['G_optim'])
            self.D_optim.load_state_dict(checkpoint['optim_states']['D_optim'])
            print("=> loaded checkpoint '{} (iter {})'".format(
                file_path, self.global_iter))
        else:
            print("=> no checkpoint found at '{}'".format(file_path))

    def train(self):
        """Run the BEGAN training loop over self.epoch epochs."""
        self.set_mode('train')

        for e in range(self.epoch):
            self.global_epoch += 1
            e_elapsed = time.time()

            for idx, (images, labels) in enumerate(self.data_loader['train']):
                self.global_iter += 1

                # Discriminator training: the BEGAN critic is an
                # autoencoder; its loss is the L1 reconstruction error.
                x_real = Variable(cuda(images, self.cuda))
                D_real = self.D(x_real)
                D_loss_real = F.l1_loss(D_real, x_real)

                z = self.sample_z()
                z = Variable(cuda(z, self.cuda))
                x_fake = self.G(z)
                D_fake = self.D(x_fake.detach())
                D_loss_fake = F.l1_loss(D_fake, x_fake)

                D_loss = D_loss_real - self.Kt * D_loss_fake

                self.D_optim.zero_grad()
                D_loss.backward()
                self.D_optim.step()

                # Generator training
                z = self.sample_z()
                z = Variable(cuda(z, self.cuda))
                x_fake = self.G(z)
                D_fake = self.D(x_fake)

                G_loss = F.l1_loss(x_fake, D_fake)

                self.G_optim.zero_grad()
                G_loss.backward()
                self.G_optim.step()

                # k_t update: keep the D/G losses at the ratio gamma.
                # (.item() replaces the removed 0-dim `.data[0]` indexing.)
                balance = (self.gamma * D_loss_real - D_loss_fake).item()
                self.Kt = max(min(self.Kt + self.lambda_k * balance, 1.0), 0.0)

                # Measure of Convergence, computed every iteration so the
                # periodic print below never hits an unbound name (the
                # original only bound it inside the visdom branch).
                M_global = D_loss_real.item() + abs(balance)

                # Visualize process
                if self.visdom and self.global_iter % 1000 == 0:
                    self.viz_train_samples.images(
                        self.unscale(x_fake).data.cpu(),
                        opts=dict(
                            title='x_fake:{:d}'.format(self.global_iter)))
                    self.viz_train_samples.images(
                        self.unscale(D_fake).data.cpu(),
                        opts=dict(
                            title='D_fake:{:d}'.format(self.global_iter)))
                    self.viz_train_samples.images(
                        self.unscale(x_real).data.cpu(),
                        opts=dict(
                            title='x_real:{:d}'.format(self.global_iter)))
                    self.viz_train_samples.images(
                        self.unscale(D_real).data.cpu(),
                        opts=dict(
                            title='D_real:{:d}'.format(self.global_iter)))

                # NOTE(review): sampling + checkpointing every 10 iters is
                # very frequent; kept as-is to preserve behavior.
                if self.visdom and self.global_iter % 10 == 0:
                    self.interpolation(self.fixed_z[0:1], self.fixed_z[1:2])
                    self.sample_img('fixed')
                    self.sample_img('random')
                    self.save_checkpoint()

                if self.visdom and self.global_iter % self.timestep == 0:
                    X = torch.Tensor([self.global_iter])
                    Y = torch.Tensor([M_global])
                    if self.win_moc is None:
                        self.win_moc = self.viz_train_curves.line(
                            X=X,
                            Y=Y,
                            opts=dict(title='MOC',
                                      fillarea=True,
                                      xlabel='iteration',
                                      ylabel='Measure of Convergence'))
                    else:
                        self.win_moc = self.viz_train_curves.line(
                            X=X, Y=Y, win=self.win_moc, update='append')

                if self.global_iter % 1000 == 0:
                    print()
                    print('iter:{:d}, M:{:.3f}'.format(self.global_iter,
                                                       M_global))
                    print(
                        'D_loss_real:{:.3f}, D_loss_fake:{:.3f}, G_loss:{:.3f}'
                        .format(D_loss_real.item(), D_loss_fake.item(),
                                G_loss.item()))

                if self.global_iter % self.lr_step_size == 0:
                    self.scheduler_step()

            e_elapsed = (time.time() - e_elapsed)
            print()
            print('epoch {:d}, [{:.2f}s]'.format(self.global_epoch, e_elapsed))

        print("[*] Training Finished!")

    def interpolation(self, z1, z2, n_step=10):
        """Render a grid of images along the latent segment z1 -> z2."""
        self.set_mode('eval')
        filename = self.output_dir.joinpath('interpolation' + ':' +
                                            str(self.global_iter) + '.jpg')

        # Build [z1, z1+s, z1+2s, ..., z2] with n_step interior points.
        step_size = (z2 - z1) / (n_step + 1)
        buff = z1
        for i in range(1, n_step + 1):
            _next = z1 + step_size * (i)
            buff = torch.cat([buff, _next], dim=0)
        buff = torch.cat([buff, z2], dim=0)

        samples = self.unscale(self.G(buff))
        grid = make_grid(samples.data.cpu(),
                         nrow=n_step + 2,
                         padding=1,
                         pad_value=0,
                         normalize=False)
        # Positional path argument: the `filename` keyword was renamed to
        # `fp` in newer torchvision releases.
        save_image(grid, filename)
        if self.visdom:
            self.viz_interpolations.image(grid,
                                          opts=dict(title=str(filename),
                                                    factor=2))

        self.set_mode('train')

    def random_interpolation(self, n_step=10):
        """Interpolate between two freshly sampled random latent points."""
        self.set_mode('eval')
        z1 = self.sample_z(1)
        z1 = Variable(cuda(z1, self.cuda))

        z2 = self.sample_z(1)
        z2 = Variable(cuda(z2, self.cuda))

        self.interpolation(z1, z2, n_step)
        self.set_mode('train')
Ejemplo n.º 6
0
Archivo: main.py Proyecto: zeta1999/CEN
        for l in range(len(l1_avg_losses)):
            l1_avg_loss, rl2_avg_loss = l1_avg_losses[l], l2_avg_losses[l]** 0.5
            fid, kid = fids[l], kids[l]
            best_note = ''
            if min_fid > fid:
                min_fid = fid
                best_note = '    (best)'
                update_best_img = True
            if l < num_parallel:
                alpha = '    %.2f' % alpha_soft[l]
                img_type_str = '(%s)' % params.img_types[l][:10]
            else:
                alpha = '        '
                img_type_str = '(ens)'
            print_log('Epoch %3d %-15s   l1_avg_loss: %.5f   rl2_avg_loss: %.5f   fid: %.3f   kid: %.3f%s%s' % \
                (epoch, img_type_str, l1_avg_loss, rl2_avg_loss, fid, kid, alpha, best_note))
        print_log('')
        if update_best_img:
            os.system('cp -r %s/fake* %s' % (save_dir, save_dir_best))
    
    if (epoch + 1) % 100 == 0:
        torch.save(G.state_dict(), os.path.join(model_dir, 'checkpoint-gen-%d.pkl' % epoch))
        torch.save(D.state_dict(), os.path.join(model_dir, 'checkpoint-dis-%d.pkl' % epoch))

# Plot average losses
# (per-epoch averages of the discriminator and generator losses).
plot_loss(D_avg_losses, G_avg_losses, params.num_epochs, save_dir=save_dir)
# Save trained parameters of model
torch.save(G.state_dict(), os.path.join(model_dir, 'checkpoint-gen.pkl'))
torch.save(D.state_dict(), os.path.join(model_dir, 'checkpoint-dis.pkl'))
# Close the run logger so buffered output is flushed.
cf.logger.close()