Example #1
    def setup_train(self, model_name, model_file_path=None):
        #batcher
        train_gen = RandomGenerator(self.train_empty_img_id,
                                    self.train_non_empty_img_id,
                                    self.config,
                                    self.masks)
        self.train_batcher = Batcher(train_gen)

        #model
        self.model = get_model(self.config)
        # Materialize parameter lists so the counts below don't exhaust the
        # iterators before they reach the optimizer.
        params = list(self.model.parameters())
        req_params = [p for p in params if p.requires_grad]
        logging.info("Number of params: %d, number of params requiring grad: %d" %
                     (sum(p.numel() for p in params),
                      sum(p.numel() for p in req_params)))
        #optimizer
        initial_lr = self.config['lr']
        self.optimizer = optim.Adam(req_params, lr=initial_lr)

        start_iter, start_loss = 0, 0
        if model_file_path is not None:
            state = torch.load(model_file_path, map_location=lambda storage, location: storage)
            self.model.load_state_dict(state['state_dict'])

            start_iter = state['iter']
            start_loss = state['current_exp_loss']
           
            self.optimizer.load_state_dict(state['optimizer'])
            if self.config['use_cuda']:
                # Move the restored optimizer state tensors onto the GPU;
                # `opt_state` avoids shadowing the checkpoint dict `state`.
                for opt_state in self.optimizer.state.values():
                    for k, v in opt_state.items():
                        if torch.is_tensor(v):
                            opt_state[k] = v.cuda()

        return start_iter, start_loss
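
For reference, a minimal sketch of the save routine that setup_train expects. The key names ('state_dict', 'optimizer', 'iter', 'current_exp_loss') come from the loading code above; the function name and call site are assumptions.

import torch

def save_train_state(model, optimizer, iteration, current_exp_loss, model_file_path):
    # Hypothetical saver producing the checkpoint layout setup_train reads.
    state = {
        'state_dict': model.state_dict(),
        'optimizer': optimizer.state_dict(),
        'iter': iteration,
        'current_exp_loss': current_exp_loss,
    }
    torch.save(state, model_file_path)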
Example #2
def main():
    start_time = time()

    in_arg = get_args_train()

    data_dir = in_arg.data_dir

    device = get_device(in_arg.gpu)
    dataloaders = get_dataloaders(data_dir)

    criterion = get_criterion()

    model = get_model(device=device,
                      arch=in_arg.arch,
                      hidden_units=in_arg.hidden_units,
                      data_dir=in_arg.data_dir,
                      save_dir=in_arg.save_dir)

    optimizer = get_optimizer(model, in_arg.learning_rate)

    train(model,
          criterion,
          optimizer,
          epochs=in_arg.epochs,
          device=device,
          train_loader=dataloaders['train'],
          valid_loader=dataloaders['valid'])

    tot_time = time() - start_time
    print(f"\n** Total Elapsed Runtime: {tot_time:.3f} seconds")
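
get_device is not shown in the example; a minimal sketch of what it plausibly does with the gpu flag (the exact behavior is an assumption):

import torch

def get_device(gpu):
    # Hypothetical helper: use CUDA only when requested and available.
    return torch.device('cuda' if gpu and torch.cuda.is_available() else 'cpu')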
Example #3
def train_ocgd(epoch_num=10,
               optim_type='BCGD2',
               startPoint=None,
               logdir='test',
               update_min=True,
               z_dim=128,
               batchsize=64,
               loss_name='WGAN',
               model_name='dc',
               data_path='None',
               dataname='cifar10',
               device='cpu',
               gpu_num=1,
               collect_info=False):
    lr_d = 0.01
    lr_g = 0.01
    dataset = get_data(dataname=dataname, path='../datas/%s' % data_path)
    dataloader = DataLoader(dataset=dataset,
                            batch_size=batchsize,
                            shuffle=True,
                            num_workers=4)
    D, G = get_model(model_name=model_name, z_dim=z_dim)
    D.to(device)
    G.to(device)
    if startPoint is not None:
        chk = torch.load(startPoint)
        D.load_state_dict(chk['D'])
        G.load_state_dict(chk['G'])
        print('Start from %s' % startPoint)

    optimizer = OCGD(max_params=G.parameters(),
                     min_params=D.parameters(),
                     update_min=update_min,
                     device=device)
    loss_list = []
    count = 0
    for e in range(epoch_num):
        for real_x in dataloader:
            real_x = real_x[0].to(device)
            d_real = D(real_x)
            z = torch.randn((real_x.shape[0], z_dim), device=device)
            fake_x = G(z)
            d_fake = D(fake_x)
            D_loss = get_loss(name=loss_name,
                              g_loss=False,
                              d_real=d_real,
                              d_fake=d_fake)
            optimizer.zero_grad()
            optimizer.step(loss=D_loss)
            if count % 100 == 0:
                print('Iter %d, Loss: %.5f' % (count, D_loss.item()))
                loss_list.append(D_loss.item())
            count += 1
        print('Epoch %d / %d done' % (e + 1, epoch_num))
    name = 'overtrainD.pth' if update_min else 'overtrainG.pth'
    save_checkpoint(path=logdir, name=name, D=D, G=G)
    loss_data = pd.DataFrame(loss_list)
    loss_data.to_csv('logs/train_oneside.csv')
Example #4
def _test_logger_load():
    from train_utils import get_model
    path = '../deeplabv3plusxception.pkl'
    model = get_model('deeplabv3plusxception', 21)
    ModelSaver.load_tool(model=model, load_path=path)
    print(model)
Example #5
    def __init__(self, train_dir):
        config_file = os.path.join(train_dir, 'config.json')
        model_dir = os.path.join(train_dir, 'model')
        model_file_path = os.path.join(model_dir, 'bestmodel')

        with open(config_file) as f:
            self.config = json.load(f)
        self.model = get_model(self.config)
        state = torch.load(model_file_path,
                           map_location=lambda storage, location: storage)
        self.model.load_state_dict(state['state_dict'])
        self.model.eval()
Example #6
def from_cpu_to_multi_gpu(model_name, n_class, load_path):
    """
    Convert a plain CPU checkpoint into an nn.DataParallel checkpoint,
    overwriting the file at load_path in place.
    :param model_name: str
    :param n_class: int
    :param load_path: str
    :return: None
    """
    from train_utils import get_model
    model = get_model(model_name, n_class)
    d = torch.load(load_path, map_location='cpu')
    model.load_state_dict(d)
    model = nn.DataParallel(model)
    torch.save(model.state_dict(), load_path)
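
The inverse direction is a common companion to the helper above; a hedged sketch, assuming the standard 'module.' prefix that nn.DataParallel adds to state_dict keys:

import torch

def from_multi_gpu_to_cpu(model_name, n_class, load_path):
    # Hypothetical inverse of from_cpu_to_multi_gpu: strip the 'module.'
    # prefix so the weights load into a plain, single-device model.
    from train_utils import get_model
    model = get_model(model_name, n_class)
    d = torch.load(load_path, map_location='cpu')
    d = {k[len('module.'):] if k.startswith('module.') else k: v
         for k, v in d.items()}
    model.load_state_dict(d)
    torch.save(model.state_dict(), load_path)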
Example #7
def setup_model(args):
    """Setup model and optimizer."""

    model = get_model(args, model_type="generation")

    # if args.deepspeed:
    #     print_rank_0("DeepSpeed is enabled.")
    #
    #     model, _, _, _ = deepspeed.initialize(
    #         model=model,
    #         model_parameters=model.parameters(),
    #         args=args,
    #         mpu=mpu,
    #         dist_init_required=False
    #     )
    if args.load_pretrained is not None:
        args.no_load_optim = True
        args.load = args.load_pretrained
        _ = load_checkpoint(model, None, None, args)
    # if args.deepspeed:
    #     model = model.module

    return model
Example #8
from train_utils import get_model
import torch
from ptflops import get_model_complexity_info

if __name__ == '__main__':
    z_dim = 100
    model_config = {
        'image_size': 64,
        'image_channel': 3,
        'feature_num': 64,
        'n_extra_layers': 0
    }
    with torch.cuda.device(0):
        D, G = get_model(model_name='DCGANs', z_dim=z_dim, configs=model_config)
        macsD, paramsD = get_model_complexity_info(
            D, (model_config['image_channel'], model_config['image_size'],
                model_config['image_size']),
            as_strings=True,
            print_per_layer_stat=True,
            verbose=True)
        print('{:<30}  {:<8}'.format('Computational complexity: ', macsD))
        print('{:<30}  {:<8}'.format('Number of parameters: ', paramsD))
        macsG, paramsG = get_model_complexity_info(G, (z_dim, 1, 1),
                                                   as_strings=True,
                                                   print_per_layer_stat=True,
                                                   verbose=True)
        print('{:<30}  {:<8}'.format('Computational complexity: ', macsG))
        print('{:<30}  {:<8}'.format('Number of parameters: ', paramsG))
Example #9
def train_cgd(epoch_num=10,
              milestone=None,
              optim_type='ACGD',
              startPoint=None,
              start_n=0,
              z_dim=128,
              batchsize=64,
              tols=None,
              l2_penalty=0.0,
              momentum=0.0,
              loss_name='WGAN',
              model_name='dc',
              data_path='None',
              show_iter=100,
              logdir='test',
              dataname='cifar10',
              device='cpu',
              gpu_num=1,
              collect_info=False):
    # Replace the removed mutable default argument with the original tolerances.
    tols = tols if tols is not None else {'tol': 1e-10, 'atol': 1e-16}
    lr_d = 0.01
    lr_g = 0.01

    dataset = get_data(dataname=dataname, path='../datas/%s' % data_path)
    dataloader = DataLoader(dataset=dataset,
                            batch_size=batchsize,
                            shuffle=True,
                            num_workers=4)
    D, G = get_model(model_name=model_name, z_dim=z_dim)
    D.apply(weights_init_d).to(device)
    G.apply(weights_init_g).to(device)
    from datetime import datetime
    current_time = datetime.now().strftime('%b%d_%H-%M-%S')
    writer = SummaryWriter(log_dir='logs/%s/%s_%.3f' %
                           (logdir, current_time, lr_d))
    if optim_type == 'BCGD':
        optimizer = BCGD(max_params=G.parameters(),
                         min_params=D.parameters(),
                         lr_max=lr_g,
                         lr_min=lr_d,
                         momentum=momentum,
                         tol=tols['tol'],
                         atol=tols['atol'],
                         device=device)
        scheduler = lr_scheduler(optimizer=optimizer, milestone=milestone)
    elif optim_type == 'ICR':
        optimizer = ICR(max_params=G.parameters(),
                        min_params=D.parameters(),
                        lr=lr_d,
                        alpha=1.0,
                        device=device)
        scheduler = icrScheduler(optimizer, milestone)
    elif optim_type == 'ACGD':
        optimizer = ACGD(max_params=G.parameters(),
                         min_params=D.parameters(),
                         lr_max=lr_g,
                         lr_min=lr_d,
                         tol=tols['tol'],
                         atol=tols['atol'],
                         device=device,
                         solver='cg')
        scheduler = lr_scheduler(optimizer=optimizer, milestone=milestone)
    else:
        raise ValueError('Unknown optim_type: %s' % optim_type)
    if startPoint is not None:
        chk = torch.load(startPoint)
        D.load_state_dict(chk['D'])
        G.load_state_dict(chk['G'])
        optimizer.load_state_dict(chk['optim'])
        print('Start from %s' % startPoint)
    if gpu_num > 1:
        D = nn.DataParallel(D, list(range(gpu_num)))
        G = nn.DataParallel(G, list(range(gpu_num)))
    timer = time.time()
    count = 0
    if model_name == 'DCGAN' or model_name == 'DCGAN-WBN':
        fixed_noise = torch.randn((64, z_dim, 1, 1), device=device)
    else:
        fixed_noise = torch.randn((64, z_dim), device=device)
    for e in range(epoch_num):
        scheduler.step(epoch=e)
        print('======Epoch: %d / %d======' % (e, epoch_num))
        for real_x in dataloader:
            real_x = real_x[0].to(device)
            d_real = D(real_x)
            if model_name == 'DCGAN' or model_name == 'DCGAN-WBN':
                z = torch.randn((d_real.shape[0], z_dim, 1, 1), device=device)
            else:
                z = torch.randn((d_real.shape[0], z_dim), device=device)
            fake_x = G(z)
            d_fake = D(fake_x)
            loss = get_loss(name=loss_name,
                            g_loss=False,
                            d_real=d_real,
                            d_fake=d_fake,
                            l2_weight=l2_penalty,
                            D=D)
            optimizer.zero_grad()
            optimizer.step(loss)

            if count % show_iter == 0:
                time_cost = time.time() - timer
                print('Iter: %d, Loss: %.5f, time: %.3fs' %
                      (count, loss.item(), time_cost))
                timer = time.time()
                with torch.no_grad():
                    fake_img = G(fixed_noise).detach()
                    path = 'figs/%s_%s/' % (dataname, logdir)
                    if not os.path.exists(path):
                        os.makedirs(path)
                    vutils.save_image(fake_img,
                                      path + 'iter_%d.png' % (count + start_n),
                                      normalize=True)
                save_checkpoint(
                    path=logdir,
                    name='%s-%s%.3f_%d.pth' %
                    (optim_type, model_name, lr_g, count + start_n),
                    D=D,
                    G=G,
                    optimizer=optimizer)
            writer.add_scalars('Discriminator output', {
                'Generated image': d_fake.mean().item(),
                'Real image': d_real.mean().item()
            },
                               global_step=count)
            writer.add_scalar('Loss', loss.item(), global_step=count)
            if collect_info:
                cgd_info = optimizer.get_info()
                writer.add_scalar('Conjugate Gradient/iter num',
                                  cgd_info['iter_num'],
                                  global_step=count)
                writer.add_scalar('Conjugate Gradient/running time',
                                  cgd_info['time'],
                                  global_step=count)
                writer.add_scalars('Delta', {
                    'D gradient': cgd_info['grad_y'],
                    'G gradient': cgd_info['grad_x'],
                    'D hvp': cgd_info['hvp_y'],
                    'G hvp': cgd_info['hvp_x'],
                    'D cg': cgd_info['cg_y'],
                    'G cg': cgd_info['cg_x']
                },
                                   global_step=count)
            count += 1
    writer.close()
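
The training loops here all delegate to a get_loss helper with loss_name='WGAN'. A minimal sketch of what it plausibly computes, assuming the standard WGAN objective; the L2 penalty on the critic is inferred from the l2_weight/D arguments and the real helper may differ:

import torch

def wgan_loss_sketch(d_real=None, d_fake=None, g_loss=False, l2_weight=0.0, D=None):
    if g_loss:
        # Generator loss: raise the critic's score on fake samples.
        return -d_fake.mean()
    # Critic loss: score fakes low and reals high.
    loss = d_fake.mean() - d_real.mean()
    if l2_weight > 0.0 and D is not None:
        # Optional weight penalty, mirroring the l2_weight/D call arguments.
        loss = loss + l2_weight * sum(p.pow(2).sum() for p in D.parameters())
    return loss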
Example #10
def train(epoch_num=10,
          milestone=None,
          optim_type='Adam',
          momentum=0.5,
          lr_d=1e-4,
          lr_g=1e-4,
          startPoint=None,
          start_n=0,
          z_dim=128,
          batchsize=64,
          loss_name='WGAN',
          model_name='dc',
          model_config=None,
          data_path='None',
          show_iter=100,
          logdir='test',
          dataname='cifar10',
          device='cpu',
          gpu_num=1,
          saturating=False):
    dataset = get_data(dataname=dataname, path=data_path)
    dataloader = DataLoader(dataset=dataset,
                            batch_size=batchsize,
                            shuffle=True,
                            num_workers=4)
    D, G = get_model(model_name=model_name, z_dim=z_dim, configs=model_config)
    D.apply(weights_init_d).to(device)
    G.apply(weights_init_g).to(device)
    from datetime import datetime
    current_time = datetime.now().strftime('%b%d_%H-%M-%S')
    # writer = SummaryWriter(log_dir='logs/%s/%s' % (logdir, current_time))
    d_optimizer = Adam(D.parameters(), lr=lr_d, betas=(momentum, 0.99))
    g_optimizer = Adam(G.parameters(), lr=lr_g, betas=(momentum, 0.99))
    if startPoint is not None:
        chk = torch.load(startPoint)
        D.load_state_dict(chk['D'])
        G.load_state_dict(chk['G'])
        d_optimizer.load_state_dict(chk['d_optim'])
        g_optimizer.load_state_dict(chk['g_optim'])
        print('Start from %s' % startPoint)
    if gpu_num > 1:
        D = nn.DataParallel(D, list(range(gpu_num)))
        G = nn.DataParallel(G, list(range(gpu_num)))
    timer = time.time()
    count = 0
    if 'DCGAN' in model_name:
        fixed_noise = torch.randn((64, z_dim, 1, 1), device=device)
    else:
        fixed_noise = torch.randn((64, z_dim), device=device)

    for e in range(epoch_num):
        print('======Epoch: %d / %d======' % (e, epoch_num))
        for real_x in dataloader:
            real_x = real_x[0].to(device)
            d_real = D(real_x)
            if 'DCGAN' in model_name:
                z = torch.randn((d_real.shape[0], z_dim, 1, 1), device=device)
            else:
                z = torch.randn((d_real.shape[0], z_dim), device=device)
            fake_x = G(z)
            d_fake = D(fake_x)
            d_loss = get_loss(name=loss_name,
                              g_loss=False,
                              d_real=d_real,
                              d_fake=d_fake)
            d_optimizer.zero_grad()
            g_optimizer.zero_grad()
            d_loss.backward()
            d_optimizer.step()

            if not saturating:
                if 'DCGAN' in model_name:
                    z = torch.randn((d_real.shape[0], z_dim, 1, 1),
                                    device=device)
                else:
                    z = torch.randn((d_real.shape[0], z_dim), device=device)
                fake_x = G(z)
                d_fake = D(fake_x)
                g_loss = get_loss(name=loss_name, g_loss=True, d_fake=d_fake)
                g_optimizer.zero_grad()
                g_loss.backward()
            else:
                g_loss = d_loss
            g_optimizer.step()

            # writer.add_scalar('Loss/D loss', d_loss.item(), count)
            # writer.add_scalar('Loss/G loss', g_loss.item(), count)
            # writer.add_scalars('Discriminator output', {'Generated image': d_fake.mean().item(),
            #                                             'Real image': d_real.mean().item()},
            #                    global_step=count)
            if count % show_iter == 0:
                time_cost = time.time() - timer
                print('Iter %d, D Loss: %.5f, G loss: %.5f, time: %.2f s' %
                      (count, d_loss.item(), g_loss.item(), time_cost))
                timer = time.time()
                with torch.no_grad():
                    fake_img = G(fixed_noise).detach()
                    path = 'figs/%s_%s/' % (dataname, logdir)
                    if not os.path.exists(path):
                        os.makedirs(path)
                    vutils.save_image(fake_img,
                                      path + 'iter_%d.png' % (count + start_n),
                                      normalize=True)
                save_checkpoint(path=logdir,
                                name='%s-%s_%d.pth' %
                                (optim_type, model_name, count + start_n),
                                D=D,
                                G=G,
                                optimizer=d_optimizer,
                                g_optimizer=g_optimizer)
            count += 1
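
All of these loops persist weights through a shared save_checkpoint helper. A hedged sketch of what it likely writes; the key names mirror what the loaders above read (chk['D'], chk['G'], chk['optim'], chk['d_optim'], chk['g_optim']), while the 'checkpoints/' root is an assumption:

import os
import torch

def save_checkpoint(path, name, D, G, optimizer=None, g_optimizer=None):
    chk = {'D': D.state_dict(), 'G': G.state_dict()}
    if optimizer is not None and g_optimizer is not None:
        chk['d_optim'] = optimizer.state_dict()   # two-optimizer runs (Adam/RMSprop)
        chk['g_optim'] = g_optimizer.state_dict()
    elif optimizer is not None:
        chk['optim'] = optimizer.state_dict()     # single CGD-family optimizer
    save_dir = os.path.join('checkpoints', path)
    os.makedirs(save_dir, exist_ok=True)
    torch.save(chk, os.path.join(save_dir, name))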
Example #11
def train_sim(epoch_num=10,
              optim_type='ACGD',
              startPoint=None,
              start_n=0,
              z_dim=128,
              batchsize=64,
              l2_penalty=0.0,
              momentum=0.0,
              log=False,
              loss_name='WGAN',
              model_name='dc',
              model_config=None,
              data_path='None',
              show_iter=100,
              logdir='test',
              dataname='CIFAR10',
              device='cpu',
              gpu_num=1):
    lr_d = 1e-4
    lr_g = 1e-4
    dataset = get_data(dataname=dataname, path=data_path)
    dataloader = DataLoader(dataset=dataset,
                            batch_size=batchsize,
                            shuffle=True,
                            num_workers=4)
    D, G = get_model(model_name=model_name, z_dim=z_dim, configs=model_config)
    D.apply(weights_init_d).to(device)
    G.apply(weights_init_g).to(device)

    optim_d = RMSprop(D.parameters(), lr=lr_d)
    optim_g = RMSprop(G.parameters(), lr=lr_g)

    if startPoint is not None:
        chk = torch.load(startPoint)
        D.load_state_dict(chk['D'])
        G.load_state_dict(chk['G'])
        optim_d.load_state_dict(chk['d_optim'])
        optim_g.load_state_dict(chk['g_optim'])
        print('Start from %s' % startPoint)
    if gpu_num > 1:
        D = nn.DataParallel(D, list(range(gpu_num)))
        G = nn.DataParallel(G, list(range(gpu_num)))
    timer = time.time()
    count = 0
    if 'DCGAN' in model_name:
        fixed_noise = torch.randn((64, z_dim, 1, 1), device=device)
    else:
        fixed_noise = torch.randn((64, z_dim), device=device)
    for e in range(epoch_num):
        print('======Epoch: %d / %d======' % (e, epoch_num))
        for real_x in dataloader:
            real_x = real_x[0].to(device)
            d_real = D(real_x)
            if 'DCGAN' in model_name:
                z = torch.randn((d_real.shape[0], z_dim, 1, 1), device=device)
            else:
                z = torch.randn((d_real.shape[0], z_dim), device=device)
            fake_x = G(z)
            d_fake = D(fake_x)
            loss = get_loss(name=loss_name,
                            g_loss=False,
                            d_real=d_real,
                            d_fake=d_fake,
                            l2_weight=l2_penalty,
                            D=D)
            D.zero_grad()
            G.zero_grad()
            loss.backward()
            optim_d.step()
            optim_g.step()

            if count % show_iter == 0:
                time_cost = time.time() - timer
                print('Iter: %d, Loss: %.5f, time: %.3fs' %
                      (count, loss.item(), time_cost))
                timer = time.time()
                with torch.no_grad():
                    fake_img = G(fixed_noise).detach()
                    path = 'figs/%s_%s/' % (dataname, logdir)
                    if not os.path.exists(path):
                        os.makedirs(path)
                    vutils.save_image(fake_img,
                                      path + 'iter_%d.png' % (count + start_n),
                                      normalize=True)
                save_checkpoint(
                    path=logdir,
                    name='%s-%s%.3f_%d.pth' %
                    (optim_type, model_name, lr_g, count + start_n),
                    D=D,
                    G=G,
                    optimizer=optim_d,
                    g_optimizer=optim_g)
            if wandb and log:
                wandb.log({
                    'Real score': d_real.mean().item(),
                    'Fake score': d_fake.mean().item(),
                    'Loss': loss.item()
                })
            count += 1
Example #12
def train_cgd(epoch_num=10,
              optim_type='ACGD',
              startPoint=None,
              start_n=0,
              z_dim=128,
              batchsize=64,
              tols=None,
              l2_penalty=0.0,
              momentum=0.0,
              loss_name='WGAN',
              model_name='dc',
              model_config=None,
              data_path='None',
              show_iter=100,
              logdir='test',
              dataname='CIFAR10',
              device='cpu',
              gpu_num=1,
              ada_train=True,
              log=False,
              collect_info=False,
              args=None):
    # Replace the removed mutable default argument with the original tolerances.
    tols = tols if tols is not None else {'tol': 1e-10, 'atol': 1e-16}
    lr_d = args['lr_d']
    lr_g = args['lr_g']
    dataset = get_data(dataname=dataname, path=data_path)
    dataloader = DataLoader(dataset=dataset,
                            batch_size=batchsize,
                            shuffle=True,
                            num_workers=4)
    D, G = get_model(model_name=model_name, z_dim=z_dim, configs=model_config)
    D.apply(weights_init_d).to(device)
    G.apply(weights_init_g).to(device)
    if optim_type == 'BCGD':
        optimizer = BCGD(max_params=G.parameters(),
                         min_params=D.parameters(),
                         lr_max=lr_g,
                         lr_min=lr_d,
                         momentum=momentum,
                         tol=tols['tol'],
                         atol=tols['atol'],
                         device=device)
        # scheduler = lr_scheduler(optimizer=optimizer, milestone=milestone)
    elif optim_type == 'ICR':
        optimizer = ICR(max_params=G.parameters(),
                        min_params=D.parameters(),
                        lr=lr_d,
                        alpha=1.0,
                        device=device)
        # scheduler = icrScheduler(optimizer, milestone)
    elif optim_type == 'ACGD':
        optimizer = ACGD(max_params=G.parameters(),
                         min_params=D.parameters(),
                         lr_max=lr_g,
                         lr_min=lr_d,
                         tol=tols['tol'],
                         atol=tols['atol'],
                         device=device,
                         solver='cg')
        # scheduler = lr_scheduler(optimizer=optimizer, milestone=milestone)
    else:
        raise ValueError('Unknown optim_type: %s' % optim_type)
    if startPoint is not None:
        chk = torch.load(startPoint)
        D.load_state_dict(chk['D'])
        G.load_state_dict(chk['G'])
        # optimizer.load_state_dict(chk['optim'])
        print('Start from %s' % startPoint)
    if gpu_num > 1:
        D = nn.DataParallel(D, list(range(gpu_num)))
        G = nn.DataParallel(G, list(range(gpu_num)))
    timer = time.time()
    count = 0
    if 'DCGAN' in model_name:
        fixed_noise = torch.randn((64, z_dim, 1, 1), device=device)
    else:
        fixed_noise = torch.randn((64, z_dim), device=device)

    # Sliding window of recent discriminator accuracies, used to adapt lr_d below.
    mod = 10
    accs = torch.tensor([0.8 for _ in range(mod)])

    for e in range(epoch_num):
        # scheduler.step(epoch=e)
        print('======Epoch: %d / %d======' % (e, epoch_num))
        for real_x in dataloader:
            real_x = real_x[0].to(device)
            d_real = D(real_x)
            if 'DCGAN' in model_name:
                z = torch.randn((d_real.shape[0], z_dim, 1, 1), device=device)
            else:
                z = torch.randn((d_real.shape[0], z_dim), device=device)
            fake_x = G(z)
            d_fake = D(fake_x)
            loss = get_loss(name=loss_name,
                            g_loss=False,
                            d_real=d_real,
                            d_fake=d_fake,
                            l2_weight=l2_penalty,
                            D=D)
            optimizer.zero_grad()
            optimizer.step(loss)

            # Discriminator accuracy: reals should score positive, fakes negative.
            num_correct = torch.sum(d_real > 0) + torch.sum(d_fake < 0)
            acc = num_correct.item() / (d_real.shape[0] + d_fake.shape[0])
            accs[count % mod] = acc
            acc_indicator = sum(accs) / mod
            if acc_indicator > 0.9:
                ada_ratio = 0.05
            elif acc_indicator < 0.80:
                ada_ratio = 0.1
            else:
                ada_ratio = 1.0
            if ada_train:
                optimizer.set_lr(lr_max=lr_g, lr_min=ada_ratio * lr_d)

            if count % show_iter == 0 and count != 0:
                time_cost = time.time() - timer
                print('Iter: %d, Loss: %.5f, time: %.3fs' %
                      (count, loss.item(), time_cost))
                timer = time.time()
                with torch.no_grad():
                    fake_img = G(fixed_noise).detach()
                    path = 'figs/%s_%s/' % (dataname, logdir)
                    if not os.path.exists(path):
                        os.makedirs(path)
                    vutils.save_image(fake_img,
                                      path + 'iter_%d.png' % (count + start_n),
                                      normalize=True)
                save_checkpoint(path=logdir,
                                name='%s-%s_%d.pth' %
                                (optim_type, model_name, count + start_n),
                                D=D,
                                G=G,
                                optimizer=optimizer)
            if wandb and log:
                wandb.log(
                    {
                        'Real score': d_real.mean().item(),
                        'Fake score': d_fake.mean().item(),
                        'Loss': loss.item(),
                        'Acc_indicator': acc_indicator,
                        'Ada ratio': ada_ratio
                    },
                    step=count,
                )

            if collect_info and wandb:
                cgd_info = optimizer.get_info()
                wandb.log(
                    {
                        'CG iter num': cgd_info['iter_num'],
                        'CG runtime': cgd_info['time'],
                        'D gradient': cgd_info['grad_y'],
                        'G gradient': cgd_info['grad_x'],
                        'D hvp': cgd_info['hvp_y'],
                        'G hvp': cgd_info['hvp_x'],
                        'D cg': cgd_info['cg_y'],
                        'G cg': cgd_info['cg_x']
                    },
                    step=count)
            count += 1
Example #13
import torch
from ptflops import get_model_complexity_info
# Assumed import: Example #8 pulls get_model from train_utils.
from train_utils import get_model

if __name__ == '__main__':
    z_dim = 128
    model_name = 'Resnet'
    model_config = {
        'image_size': 64,
        'image_channel': 3,
        'feature_num': 128,
        'n_extra_layers': 0,
        'batchnorm_d': True,
        'batchnorm_g': True
    }
    with torch.cuda.device(0):
        D, G = get_model(model_name=model_name,
                         z_dim=z_dim,
                         configs=model_config)
        macsD, paramsD = get_model_complexity_info(
            D, (model_config['image_channel'], model_config['image_size'],
                model_config['image_size']),
            as_strings=True,
            print_per_layer_stat=True,
            verbose=True)
        print('{:<30}  {:<8}'.format('Computational complexity: ', macsD))
        print('{:<30}  {:<8}'.format('Number of parameters: ', paramsD))
        macsG, paramsG = get_model_complexity_info(G, (z_dim, ),
                                                   as_strings=True,
                                                   print_per_layer_stat=True,
                                                   verbose=True)
        print('{:<30}  {:<8}'.format('Computational complexity: ', macsG))
        print('{:<30}  {:<8}'.format('Number of parameters: ', paramsG))
Example #14
def train_scg(config, tols, milestone, device='cpu'):
    lr_d = config['lr_d']
    lr_g = config['lr_g']
    optim_type = config['optimizer']
    z_dim = config['z_dim']
    model_name = config['model']
    epoch_num = config['epoch_num']
    show_iter = config['show_iter']
    loss_name = config['loss_type']
    l2_penalty = config['d_penalty']
    logdir = config['logdir']
    start_n = config['startn']
    dataset = get_data(dataname=config['dataset'], path='../datas/%s' % config['datapath'])
    dataloader = DataLoader(dataset=dataset, batch_size=config['batchsize'],
                            shuffle=True, num_workers=4)
    inner_loader = DataLoader(dataset=dataset, batch_size=config['batchsize'],
                              shuffle=True, num_workers=4)
    D, G = get_model(model_name=model_name, z_dim=z_dim)
    D.apply(weights_init_d).to(device)
    G.apply(weights_init_g).to(device)
    optimizer = SCG(max_params=G.parameters(), min_params=D.parameters(),
                    lr_max=lr_g, lr_min=lr_d,
                    tol=tols['tol'], atol=tols['atol'],
                    dataloader=inner_loader,
                    device=device, solver='cg')
    scheduler = lr_scheduler(optimizer=optimizer, milestone=milestone)
    if config['checkpoint'] is not None:
        startPoint = config['checkpoint']
        chk = torch.load(startPoint)
        D.load_state_dict(chk['D'])
        G.load_state_dict(chk['G'])
        optimizer.load_state_dict(chk['optim'])
        print('Start from %s' % startPoint)
    gpu_num = config['gpu_num']
    if gpu_num > 1:
        D = nn.DataParallel(D, list(range(gpu_num)))
        G = nn.DataParallel(G, list(range(gpu_num)))

    timer = time.time()
    count = 0
    if model_name == 'DCGAN' or model_name == 'DCGAN-WBN':
        fixed_noise = torch.randn((64, z_dim, 1, 1), device=device)
    else:
        fixed_noise = torch.randn((64, z_dim), device=device)
    for e in range(epoch_num):
        scheduler.step(epoch=e)
        print('======Epoch: %d / %d======' % (e, epoch_num))
        for real_x in dataloader:
            optimizer.zero_grad()
            real_x = real_x[0]
            if model_name == 'DCGAN' or model_name == 'DCGAN-WBN':
                z = torch.randn((real_x.shape[0], z_dim, 1, 1), device=device)
            else:
                z = torch.randn((real_x.shape[0], z_dim), device=device)
            # Loss closure re-evaluated inside the SCG optimizer's step.
            def closure(train_x):
                train_x = train_x.to(device)
                fake_x = G(z)
                d_fake = D(fake_x)
                d_real = D(train_x)
                loss = get_loss(name=loss_name, g_loss=False,
                                d_real=d_real, d_fake=d_fake,
                                l2_weight=l2_penalty, D=D)
                return loss
            loss = optimizer.step(closure=closure, img=real_x)
            if count % show_iter == 0:
                time_cost = time.time() - timer
                print('Iter: %d, Loss: %.5f, time: %.3fs'
                      % (count, loss.item(), time_cost))
                timer = time.time()
                with torch.no_grad():
                    fake_img = G(fixed_noise).detach()
                    path = 'figs/%s_%s/' % (config['dataset'], logdir)
                    if not os.path.exists(path):
                        os.makedirs(path)
                    vutils.save_image(fake_img, path + 'iter_%d.png' % (count + start_n), normalize=True)
                save_checkpoint(path=logdir,
                                name='%s-%s%.3f_%d.pth' % (optim_type, model_name, lr_g, count + start_n),
                                D=D, G=G, optimizer=optimizer)
            count += 1