# NOTE(review): this chunk opens mid-argument-list — the leading values below
# are the tail of a model-constructor call whose opening line is outside this
# view (presumably PixelSNAIL(...) given the parameter names; confirm against
# the full file).
args.channel, 5, 4, args.n_res_block, args.n_res_channel,
attention=False, dropout=args.dropout,
n_cond_res_block=args.n_cond_res_block,
cond_res_channel=args.n_res_channel,
)

# Restore model weights when the loaded checkpoint dict carries them.
if 'model' in ckpt: model.load_state_dict(ckpt['model'])

model = model.to(device)

optimizer = optim.Adam(model.parameters(), lr=args.lr)

# Optional mixed-precision wrapping; `amp` is presumably None when the
# apex package is unavailable (import guarded elsewhere — TODO confirm).
if amp is not None: model, optimizer = amp.initialize(model, optimizer, opt_level=args.amp)

# Wrap for multi-GPU; the second .to(device) after DataParallel repeats the
# earlier move (a no-op if the model is already on `device`).
model = nn.DataParallel(model)
model = model.to(device)

# Learning-rate schedule: only the 'cycle' policy is handled here; any other
# value of args.sched leaves scheduler as None.
scheduler = None
if args.sched == 'cycle': scheduler = CycleScheduler(optimizer, args.lr, n_iter=len(loader) * args.epoch, momentum=None)

logging = init_logging('log')
# NOTE(review): the first statement below is the tail of the train() loop —
# its enclosing `def`/`for` lines are outside this view; confirm against the
# full file. `loader` here is a tqdm-style wrapper exposing set_description.
loader.set_description(
    f'epoch: {epoch + 1}; loss: {loss.item():.5f}; '
    f'acc: {accuracy:.5f}'
)


class PixelTransform:
    """Convert a PIL image to a LongTensor of raw pixel values.

    Used as a torchvision transform so the model sees integer pixel
    classes rather than normalized floats.
    """

    def __init__(self):
        pass

    def __call__(self, input):
        # np.array(PIL image) -> uint8 array; .long() yields int64 class ids.
        ar = np.array(input)
        return torch.from_numpy(ar).long()


if __name__ == '__main__':
    import os

    device = 'cuda'  # NOTE(review): hard-coded; crashes on CPU-only hosts.
    epoch = 10  # number of training epochs

    dataset = datasets.MNIST('.', transform=PixelTransform(), download=True)
    loader = DataLoader(dataset, batch_size=32, shuffle=True, num_workers=4)

    model = PixelSNAIL([28, 28], 256, 128, 5, 2, 4, 128)
    model = model.to(device)
    optimizer = optim.Adam(model.parameters(), lr=1e-3)

    # Fix: ensure the save directory exists (torch.save does not create it),
    # and drive the loop from `epoch` instead of a duplicated hard-coded 10.
    os.makedirs('checkpoint', exist_ok=True)
    for i in range(epoch):
        train(i, loader, model, optimizer, device)
        torch.save(
            model.state_dict(), f'checkpoint/mnist_{str(i + 1).zfill(3)}.pt'
        )