def main():
    """Train MNISTNet from a JSON config, saving a checkpoint after every epoch.

    Command-line flags:
        --config:   path to the JSON config (default 'configs/config.json').
        --no-cuda:  run on CPU even when CUDA is available.
        --parallel: wrap the model in nn.DataParallel.

    Config keys read here:
        'adam':       keyword arguments forwarded to optim.Adam.
        'steplr':     keyword arguments forwarded to optim.lr_scheduler.StepLR.
        'dataset':    keyword arguments forwarded to mnist_loader.
        'output_dir': parent directory; a timestamped sub-directory is created
                      inside it for this run.
        'epochs':     number of epochs to run.

    A copy of the config and one 'model_NNNN.pt' state dict per epoch are
    written to the run's output directory.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('--config', type=str, default='configs/config.json')
    parser.add_argument('--no-cuda', action='store_true')
    parser.add_argument('--parallel', action='store_true')
    args = parser.parse_args()
    args.cuda = torch.cuda.is_available() and not args.no_cuda
    print(args)

    device = torch.device('cuda' if args.cuda else 'cpu')
    config = load_json(args.config)

    model = MNISTNet()
    if args.parallel:
        model = nn.DataParallel(model)
    model.to(device)

    optimizer = optim.Adam(model.parameters(), **config['adam'])
    scheduler = optim.lr_scheduler.StepLR(optimizer, **config['steplr'])
    train_loader, valid_loader = mnist_loader(**config['dataset'])
    trainer = Trainer(model, optimizer, train_loader, valid_loader, device)

    # Each run gets its own timestamped directory so runs never overwrite
    # each other's checkpoints.
    output_dir = os.path.join(
        config['output_dir'], datetime.now().strftime('%Y%m%d_%H%M%S'))
    os.makedirs(output_dir, exist_ok=True)

    # save config to output dir
    save_json(config, os.path.join(output_dir, 'config.json'))

    num_epochs = config['epochs']
    for epoch in range(num_epochs):
        train_loss, train_acc = trainer.train()
        valid_loss, valid_acc = trainer.validate()
        # Advance the LR schedule only after this epoch's optimizer updates
        # (trainer.train() is expected to perform them). Stepping the
        # scheduler first skips the initial learning-rate value on
        # PyTorch >= 1.1 and applies every decay one epoch too early.
        scheduler.step()

        print(
            'epoch: {}/{},'.format(epoch + 1, num_epochs),
            'train loss: {:.4f}, train acc: {:.2f}%,'.format(
                train_loss, train_acc * 100),
            'valid loss: {:.4f}, valid acc: {:.2f}%'.format(
                valid_loss, valid_acc * 100))

        torch.save(
            model.state_dict(),
            os.path.join(output_dir, 'model_{:04d}.pt'.format(epoch + 1)))
def run(config, norm2d):
    """Train CIFAR10Net with the given normalization and return validation accuracies.

    Args:
        config: object providing the attributes read here: ``root`` and
            ``batch_size`` (passed to cifar10_loader), ``cuda``, ``lr``,
            ``epochs`` and ``save_dir``.
        norm2d: passed to CIFAR10Net as ``norm2d`` and also used as the name
            of the checkpoint sub-directory under ``config.save_dir``.

    Returns:
        A list with the validation accuracy reported by the trainer for each
        epoch, in epoch order.

    A 'model_NNNN.pt' state dict is saved after every epoch.
    """
    train_loader, valid_loader = cifar10_loader(config.root, config.batch_size)

    model = CIFAR10Net(norm2d=norm2d)
    if config.cuda:
        model.cuda()

    optimizer = optim.Adam(model.parameters(), lr=config.lr, weight_decay=1e-4)
    scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=100, gamma=0.1)
    trainer = Trainer(model, optimizer, train_loader, valid_loader,
                      use_cuda=config.cuda)

    # The checkpoint directory does not change between epochs, so it is
    # prepared once up front instead of on every iteration.
    save_dir = os.path.join(config.save_dir, norm2d)
    os.makedirs(save_dir, exist_ok=True)

    valid_acc_list = []
    for epoch in range(config.epochs):
        start = time()
        train_loss, train_acc = trainer.train(epoch)
        valid_loss, valid_acc = trainer.validate()
        # Advance the LR schedule only after this epoch's optimizer updates
        # (trainer.train() is expected to perform them). Stepping the
        # scheduler first skips the initial learning-rate value on
        # PyTorch >= 1.1 and applies every decay one epoch too early.
        scheduler.step()

        print(
            'epoch: {}/{},'.format(epoch + 1, config.epochs),
            'train loss: {:.4f}, train acc: {:.2f}%,'.format(
                train_loss, train_acc * 100),
            'valid loss: {:.4f}, valid acc: {:.2f}%,'.format(
                valid_loss, valid_acc * 100),
            'time: {:.2f}s'.format(time() - start))

        torch.save(
            model.state_dict(),
            os.path.join(save_dir, 'model_{:04d}.pt'.format(epoch + 1)))
        valid_acc_list.append(valid_acc)

    return valid_acc_list