예제 #1
0
def main():
    """Train an autoregressive network on an Ising model with a REINFORCE estimator.

    All configuration is read from the module-level ``args``. If a checkpoint
    is found (via ``get_last_checkpoint_step``) training resumes from the step
    after it; otherwise the log is cleared and training starts at step 0 and
    runs through ``args.max_step`` inclusive.

    Each step draws ``args.batch_size`` samples from the network, forms the
    per-sample loss ``log_prob + beta * energy`` (``beta`` annealed towards
    ``args.beta``) and takes one optimizer step. Every ``args.print_step`` /
    ``args.save_step`` / ``args.visual_step`` steps it logs statistics,
    saves a checkpoint, or writes a sample image respectively.

    Raises:
        ValueError: if ``args.net`` or ``args.optimizer`` is not recognised.
    """
    start_time = time.time()

    init_out_dir()
    if args.clear_checkpoint:
        clear_checkpoint()
    last_step = get_last_checkpoint_step()
    if last_step >= 0:
        my_log('\nCheckpoint found: {}\n'.format(last_step))
    else:
        clear_log()
    print_args()

    if args.net == 'made':
        net = MADE(**vars(args))
    elif args.net == 'pixelcnn':
        net = PixelCNN(**vars(args))
    elif args.net == 'bernoulli':
        net = BernoulliMixture(**vars(args))
    else:
        raise ValueError('Unknown net: {}'.format(args.net))
    net.to(args.device)
    my_log('{}\n'.format(net))

    params = list(net.parameters())
    params = list(filter(lambda p: p.requires_grad, params))
    nparams = int(sum([np.prod(p.shape) for p in params]))
    my_log('Total number of trainable parameters: {}'.format(nparams))
    # Kept for the optional per-parameter gradient report below.
    named_params = list(net.named_parameters())

    if args.optimizer == 'sgd':
        optimizer = torch.optim.SGD(params, lr=args.lr)
    elif args.optimizer == 'sgdm':
        optimizer = torch.optim.SGD(params, lr=args.lr, momentum=0.9)
    elif args.optimizer == 'rmsprop':
        optimizer = torch.optim.RMSprop(params, lr=args.lr, alpha=0.99)
    elif args.optimizer == 'adam':
        optimizer = torch.optim.Adam(params, lr=args.lr, betas=(0.9, 0.999))
    elif args.optimizer == 'adam0.5':
        optimizer = torch.optim.Adam(params, lr=args.lr, betas=(0.5, 0.999))
    else:
        raise ValueError('Unknown optimizer: {}'.format(args.optimizer))

    if args.lr_schedule:
        # 0.92**80 ~ 1e-3
        scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
            optimizer, factor=0.92, patience=100, threshold=1e-4, min_lr=1e-6)

    if last_step >= 0:
        # map_location lets a checkpoint saved on one device (e.g. a GPU) be
        # resumed on another (e.g. CPU-only); without it torch.load tries to
        # restore tensors onto the device they were saved from.
        state = torch.load('{}_save/{}.state'.format(args.out_filename,
                                                     last_step),
                           map_location=args.device)
        ignore_param(state['net'], net)
        net.load_state_dict(state['net'])
        if state.get('optimizer'):
            optimizer.load_state_dict(state['optimizer'])
        if args.lr_schedule and state.get('scheduler'):
            scheduler.load_state_dict(state['scheduler'])

    init_time = time.time() - start_time
    my_log('init_time = {:.3f}'.format(init_time))

    my_log('Training...')
    sample_time = 0
    train_time = 0
    start_time = time.time()
    for step in range(last_step + 1, args.max_step + 1):
        optimizer.zero_grad()

        sample_start_time = time.time()
        with torch.no_grad():
            sample, x_hat = net.sample(args.batch_size)
        assert not sample.requires_grad
        assert not x_hat.requires_grad
        sample_time += time.time() - sample_start_time

        train_start_time = time.time()

        log_prob = net.log_prob(sample)
        # 0.998**9000 ~ 1e-8
        beta = args.beta * (1 - args.beta_anneal**step)
        with torch.no_grad():
            energy = ising.energy(sample, args.ham, args.lattice,
                                  args.boundary)
            loss = log_prob + beta * energy
        assert not energy.requires_grad
        assert not loss.requires_grad
        # Score-function (REINFORCE) surrogate: only log_prob carries grad;
        # subtracting loss.mean() is a baseline that reduces variance.
        loss_reinforce = torch.mean((loss - loss.mean()) * log_prob)
        loss_reinforce.backward()

        if args.clip_grad:
            nn.utils.clip_grad_norm_(params, args.clip_grad)

        optimizer.step()

        if args.lr_schedule:
            scheduler.step(loss.mean())

        train_time += time.time() - train_start_time

        if args.print_step and step % args.print_step == 0:
            # NOTE(review): normalised by args.beta, not the annealed beta
            # used in `loss`; the two agree only once annealing has settled.
            free_energy_mean = loss.mean() / args.beta / args.L**2
            free_energy_std = loss.std() / args.beta / args.L**2
            entropy_mean = -log_prob.mean() / args.L**2
            energy_mean = energy.mean() / args.L**2
            mag = sample.mean(dim=0)
            mag_mean = mag.mean()
            mag_sqr_mean = (mag**2).mean()
            if step > 0:
                # Report average per-step times over the print window.
                sample_time /= args.print_step
                train_time /= args.print_step
            used_time = time.time() - start_time
            my_log(
                'step = {}, F = {:.8g}, F_std = {:.8g}, S = {:.8g}, E = {:.8g}, M = {:.8g}, Q = {:.8g}, lr = {:.3g}, beta = {:.8g}, sample_time = {:.3f}, train_time = {:.3f}, used_time = {:.3f}'
                .format(
                    step,
                    free_energy_mean.item(),
                    free_energy_std.item(),
                    entropy_mean.item(),
                    energy_mean.item(),
                    mag_mean.item(),
                    mag_sqr_mean.item(),
                    optimizer.param_groups[0]['lr'],
                    beta,
                    sample_time,
                    train_time,
                    used_time,
                ))
            sample_time = 0
            train_time = 0

            if args.save_sample:
                state = {
                    'sample': sample,
                    'x_hat': x_hat,
                    'log_prob': log_prob,
                    'energy': energy,
                    'loss': loss,
                }
                torch.save(state, '{}_save/{}.sample'.format(
                    args.out_filename, step))

        if (args.out_filename and args.save_step
                and step % args.save_step == 0):
            state = {
                'net': net.state_dict(),
                'optimizer': optimizer.state_dict(),
            }
            if args.lr_schedule:
                state['scheduler'] = scheduler.state_dict()
            torch.save(state, '{}_save/{}.state'.format(
                args.out_filename, step))

        if (args.out_filename and args.visual_step
                and step % args.visual_step == 0):
            torchvision.utils.save_image(
                sample,
                '{}_img/{}.png'.format(args.out_filename, step),
                nrow=int(sqrt(sample.shape[0])),
                padding=0,
                normalize=True)

            if args.print_sample:
                x_hat_np = x_hat.view(x_hat.shape[0], -1).cpu().numpy()
                x_hat_std = np.std(x_hat_np, axis=0).reshape([args.L] * 2)

                # Per-site maximum absolute correlation with any earlier
                # site (strict lower triangle of the correlation matrix).
                x_hat_cov = np.cov(x_hat_np.T)
                x_hat_cov_diag = np.diag(x_hat_cov)
                # NOTE(review): `sqrt` is applied to an array here, so it
                # must be array-aware (e.g. numpy's); confirm the import.
                x_hat_corr = x_hat_cov / (
                    sqrt(x_hat_cov_diag[:, None] * x_hat_cov_diag[None, :]) +
                    args.epsilon)
                x_hat_corr = np.tril(x_hat_corr, -1)
                x_hat_corr = np.max(np.abs(x_hat_corr), axis=1)
                x_hat_corr = x_hat_corr.reshape([args.L] * 2)

                energy_np = energy.cpu().numpy()
                energy_count = np.stack(
                    np.unique(energy_np, return_counts=True)).T

                my_log(
                    '\nsample\n{}\nx_hat\n{}\nlog_prob\n{}\nenergy\n{}\nloss\n{}\nx_hat_std\n{}\nx_hat_corr\n{}\nenergy_count\n{}\n'
                    .format(
                        sample[:args.print_sample, 0],
                        x_hat[:args.print_sample, 0],
                        log_prob[:args.print_sample],
                        energy[:args.print_sample],
                        loss[:args.print_sample],
                        x_hat_std,
                        x_hat_corr,
                        energy_count,
                    ))

            if args.print_grad:
                my_log('grad max_abs min_abs mean std')
                for name, param in named_params:
                    if param.grad is not None:
                        grad = param.grad
                        grad_abs = torch.abs(grad)
                        my_log('{} {:.3g} {:.3g} {:.3g} {:.3g}'.format(
                            name,
                            torch.max(grad_abs).item(),
                            torch.min(grad_abs).item(),
                            torch.mean(grad).item(),
                            torch.std(grad).item(),
                        ))
                    else:
                        my_log('{} None'.format(name))
                my_log('')
예제 #2
0
    # Take the train and validation splits from `mnist` (converted with
    # torch.from_numpy, so presumably numpy arrays) and move them to `device`.
    xtr, xte = mnist['train_data'], mnist['valid_data']
    xtr = torch.from_numpy(xtr).to(device)
    xte = torch.from_numpy(xte).to(device)

    # construct model and ship to `device`; args.hiddens is a
    # comma-separated list of hidden-layer sizes, e.g. "500,500"
    hidden_list = list(map(int, args.hiddens.split(',')))
    # Input and output widths both equal the number of columns of xtr.
    model = MADE(xtr.size(1),
                 hidden_list,
                 xtr.size(1),
                 num_masks=args.num_masks)
    print("number of model parameters:",
          sum([np.prod(p.size()) for p in model.parameters()]))
    model.to(device)

    # set up the optimizer: Adam with lr 1e-3 and weight decay 1e-4; the
    # learning rate is multiplied by 0.1 every 45 epochs
    opt = torch.optim.Adam(model.parameters(), 1e-3, weight_decay=1e-4)
    scheduler = torch.optim.lr_scheduler.StepLR(opt, step_size=45, gamma=0.1)

    # start the training
    for epoch in range(100):
        print("epoch %d" % (epoch, ))
        # NOTE(review): passing an explicit epoch to scheduler.step() is
        # deprecated in PyTorch, and calling it before any optimizer.step()
        # triggers a warning. The equivalent modern form is a bare
        # scheduler.step() placed after run_epoch('train').
        scheduler.step(epoch)
        run_epoch(
            'test',
            upto=5)  # run only a few batches for approximate test accuracy
        run_epoch('train')
        # Passes the epoch number to model.sample; what it does with it is
        # defined outside this view.
        model.sample(epoch)

    print("optimization done. full test set eval:")
    run_epoch('test')
예제 #3
0
    # Load trained network weights from `state_filename` onto args.device.
    print(state_filename)
    state = torch.load(state_filename, map_location=args.device)
    ignore_param(state['net'], net)
    net.load_state_dict(state['net'])

    # Running sums and sums of squares of the per-site free energy (F),
    # entropy (S) and energy (E), accumulated over all sampled configurations.
    F_sum = 0
    F_sqr_sum = 0
    S_sum = 0
    S_sqr_sum = 0
    E_sum = 0
    E_sqr_sum = 0
    start_time = time.time()
    for step in range(args.max_step):
        with torch.no_grad():
            sample, x_hat = net.sample(args.batch_size)
            log_prob = net._log_prob(sample, x_hat)
            # NOTE(review): passes args.model as the Hamiltonian name, while
            # other callers of ising.energy pass args.ham — confirm which
            # attribute this script's parser defines.
            energy = ising.energy(sample, args.model, args.lattice,
                                  args.boundary) / args.L**2
            # Per-site F = E + (1 / beta) * log q, and S = -log q per site.
            free_energy = energy + 1 / args.beta * log_prob / args.L**2
            entropy = -log_prob / args.L**2

            F_sum += free_energy.sum().item()
            F_sqr_sum += (free_energy**2).sum().item()
            S_sum += entropy.sum().item()
            S_sqr_sum += (entropy**2).sum().item()
            E_sum += energy.sum().item()
            E_sqr_sum += (energy**2).sum().item()

        if args.print_step and (step + 1) % args.print_step == 0:
            # Number of configurations accumulated so far.
            count = args.batch_size * (step + 1)
예제 #4
0
def BuckyBall():
    """Train a MADE network on the buckyball Hamiltonian over a sweep of betas.

    All configuration is read from the module-level ``args``. Starting at
    ``args.beta`` and increasing by ``args.beta_inc`` up to
    ``args.beta_anneal_to`` (raised to ``args.beta`` if it is smaller), the
    network is trained for ``args.max_step`` REINFORCE steps at each beta,
    continuing from the weights reached at the previous beta. After each beta
    one line ``n seed beta F F_std E S`` (per-spin values from the final
    training step at that beta) is appended to ``args.fname``.

    Raises:
        ValueError: if ``args.ham`` or ``args.optimizer`` is not recognised.
    """
    start_time = time.time()

    init_out_dir()
    print_args()

    if args.ham == 'buckey':
        ham = buckyball_2(args.beta)
    else:
        raise ValueError('Unknown ham: {}'.format(args.ham))

    net = MADE(**vars(args))
    net.to(args.device)
    my_log('{}\n'.format(net))

    params = list(net.parameters())
    params = list(filter(lambda p: p.requires_grad, params))
    nparams = int(sum([np.prod(p.shape) for p in params]))
    my_log('Total number of trainable parameters: {}'.format(nparams))

    if args.optimizer == 'sgd':
        optimizer = torch.optim.SGD(params, lr=args.lr)
    elif args.optimizer == 'sgdm':
        optimizer = torch.optim.SGD(params, lr=args.lr, momentum=0.9)
    elif args.optimizer == 'rmsprop':
        optimizer = torch.optim.RMSprop(params, lr=args.lr, alpha=0.99)
    elif args.optimizer == 'adam':
        optimizer = torch.optim.Adam(params, lr=args.lr, betas=(0.9, 0.999))
    elif args.optimizer == 'adam0.5':
        optimizer = torch.optim.Adam(params, lr=args.lr, betas=(0.5, 0.999))
    else:
        raise ValueError('Unknown optimizer: {}'.format(args.optimizer))

    init_time = time.time() - start_time
    my_log('init_time = {:.3f}'.format(init_time))

    my_log('Training...')
    sample_time = 0
    train_time = 0
    start_time = time.time()
    if args.beta_anneal_to < args.beta:
        args.beta_anneal_to = args.beta
    # Betas are derived from an integer index (not accumulated with +=) and
    # compared with a small tolerance, so floating-point drift cannot
    # silently drop the last beta of the sweep.
    beta_tol = 1e-9 * max(1.0, abs(args.beta_anneal_to))
    beta_index = 0
    beta = args.beta
    while beta <= args.beta_anneal_to + beta_tol:
        for step in range(args.max_step):
            optimizer.zero_grad()

            sample_start_time = time.time()
            with torch.no_grad():
                sample, x_hat = net.sample(args.batch_size)
            assert not sample.requires_grad
            assert not x_hat.requires_grad
            sample_time += time.time() - sample_start_time

            train_start_time = time.time()

            log_prob = net.log_prob(sample)
            with torch.no_grad():
                energy = ham.energy(sample)
                loss = log_prob + beta * energy
            assert not energy.requires_grad
            assert not loss.requires_grad
            # Score-function (REINFORCE) surrogate with a mean baseline.
            loss_reinforce = torch.mean((loss - loss.mean()) * log_prob)
            loss_reinforce.backward()

            if args.clip_grad > 0:
                # Rescale all gradients jointly so their global L2 norm is at
                # most args.clip_grad (same rule as
                # torch.nn.utils.clip_grad_norm_, with args.epsilon as the
                # stabiliser). The norm is accumulated over every parameter
                # before any scaling, and gradients are only scaled down.
                grads = [p.grad for p in params if p.grad is not None]
                total_norm = sum(
                    g.detach().norm(2).item()**2 for g in grads)**0.5
                clip_coef = float(args.clip_grad) / (total_norm + args.epsilon)
                if clip_coef < 1:
                    for g in grads:
                        g.detach().mul_(clip_coef)

            optimizer.step()

            train_time += time.time() - train_start_time

            if args.print_step and step % args.print_step == 0:
                free_energy_mean = loss.mean() / beta / args.n
                free_energy_std = loss.std() / beta / args.n
                entropy_mean = -log_prob.mean() / args.n
                energy_mean = energy.mean() / args.n
                mag = sample.mean(dim=0)
                mag_mean = mag.mean()
                if step > 0:
                    # Report average per-step times over the print window.
                    sample_time /= args.print_step
                    train_time /= args.print_step
                used_time = time.time() - start_time
                my_log(
                    'beta = {:.3g}, # {}, F = {:.8g}, F_std = {:.8g}, S = {:.5g}, E = {:.5g}, M = {:.5g}, sample_time = {:.3f}, train_time = {:.3f}, used_time = {:.3f}'
                    .format(
                        beta,
                        step,
                        free_energy_mean.item(),
                        free_energy_std.item(),
                        entropy_mean.item(),
                        energy_mean.item(),
                        mag_mean.item(),
                        sample_time,
                        train_time,
                        used_time,
                    ))
                sample_time = 0
                train_time = 0

        # Summarise the final training step at this beta. Previously the
        # values from the last *print* step were reused, which could be stale
        # and were undefined (NameError) when print_step was disabled.
        free_energy_mean = loss.mean() / beta / args.n
        free_energy_std = loss.std() / beta / args.n
        entropy_mean = -log_prob.mean() / args.n
        energy_mean = energy.mean() / args.n
        with open(args.fname, 'a', newline='\n') as f:
            f.write('{} {} {:.3g} {:.8g} {:.8g} {:.8g} {:.8g}\n'.format(
                args.n,
                args.seed,
                beta,
                free_energy_mean.item(),
                free_energy_std.item(),
                energy_mean.item(),
                entropy_mean.item(),
            ))

        # NOTE(review): unreachable while only 'buckey' is accepted above;
        # kept for when a 'hop' Hamiltonian is wired in.
        if args.ham == 'hop':
            ensure_dir(args.out_filename + '_sample/')
            np.savetxt('{}_sample/sample{:.2f}.txt'.format(
                args.out_filename, beta),
                       sample.cpu().numpy(),
                       delimiter=' ',
                       fmt='%d')
            np.savetxt('{}_sample/log_prob{:.2f}.txt'.format(
                args.out_filename, beta),
                       log_prob.cpu().detach().numpy(),
                       delimiter=' ',
                       fmt='%.5f')

        if args.beta_inc <= 0:
            # A non-positive increment would never advance beta, so the
            # original loop would run forever; stop after this beta instead.
            break
        beta_index += 1
        beta = args.beta + beta_index * args.beta_inc
예제 #5
0
    ignore_param(state['net'], net)
    net.load_state_dict(state['net'])

    F_sum = 0
    F_sqr_sum = 0
    S_sum = 0
    S_sqr_sum = 0
    E_sum = 0
    E_sqr_sum = 0
    M_sum = 0
    M_sqr_sum = 0
    M_quad_sum = 0
    start_time = time.time()
    for step in range(args.max_step):
        with torch.no_grad():
            sample, _ = net.sample(args.batch_size)
            log_prob = net.log_prob(sample)
            energy = ising.energy(sample, args.ham, args.lattice,
                                  args.boundary) / args.L**2
            free_energy = energy + 1 / args.beta * log_prob / args.L**2
            entropy = -log_prob / args.L**2
            mag = sample.mean(dim=[1, 2, 3])

            F_sum += free_energy.sum().item()
            F_sqr_sum += (free_energy**2).sum().item()
            S_sum += entropy.sum().item()
            S_sqr_sum += (entropy**2).sum().item()
            E_sum += energy.sum().item()
            E_sqr_sum += (energy**2).sum().item()
            M_sum += mag.abs().sum().item()
            M_sqr_sum += (mag**2).sum().item()