import time

import torch
from torch import nn
from torch.utils.data import DataLoader
from torchvision.datasets import ImageFolder
from tqdm import tqdm

# `conf`, `transform`, and `Network` are assumed to be provided by the
# surrounding project.


def train(**kwargs):
    # Override config defaults with any keyword arguments passed in.
    conf.parse(kwargs)

    # train_set = DataSet(cfg, train=True, test=False)
    train_set = ImageFolder(conf.TRAIN_DATA_ROOT, transform)
    train_loader = DataLoader(train_set, conf.BATCH_SIZE, shuffle=True,
                              num_workers=conf.NUM_WORKERS)

    model = Network()
    if conf.LOAD_MODEL_PATH:
        print(conf.LOAD_MODEL_PATH)
        model.load_state_dict(torch.load(conf.CHECKPOINTS_ROOT + conf.LOAD_MODEL_PATH))

    device = torch.device('cuda:0' if conf.USE_GPU else 'cpu')
    criterion = nn.CrossEntropyLoss().to(device)
    lr = conf.LEARNING_RATE
    optim = torch.optim.Adam(params=model.parameters(), lr=lr,
                             weight_decay=conf.WEIGHT_DECAY)
    model.to(device)

    for epoch in range(conf.MAX_EPOCH):
        model.train()
        running_loss = 0
        for step, (inputs, targets) in tqdm(enumerate(train_loader)):
            inputs, targets = inputs.to(device), targets.to(device)

            optim.zero_grad()
            outs = model(inputs)
            loss = criterion(outs, targets)
            loss.backward()
            optim.step()

            # Report the average loss every PRINT_FREQ steps.
            running_loss += loss.item()
            if step % conf.PRINT_FREQ == conf.PRINT_FREQ - 1:
                running_loss = running_loss / conf.PRINT_FREQ
                print('[%d, %5d] loss: %.3f' % (epoch + 1, step + 1, running_loss))
                # vis.plot('loss', running_loss)
                running_loss = 0

        torch.save(model.state_dict(),
                   conf.CHECKPOINTS_ROOT + time.strftime('%Y-%m-%d-%H-%M-%S.pth'))

        # Decay the learning rate once per epoch; the original decayed it once
        # per param group, compounding the decay for multi-group optimizers.
        lr *= conf.LEARNING_RATE_DECAY
        for param_group in optim.param_groups:
            param_group['lr'] = lr
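# The `conf` object above follows a common "default config + parse(kwargs)"
# pattern. A minimal sketch of such a config object, assuming that pattern;
# all attribute values here are illustrative defaults, not the original
# project's settings.
import warnings


class DefaultConfig:
    TRAIN_DATA_ROOT = './data/train'
    CHECKPOINTS_ROOT = './checkpoints/'
    LOAD_MODEL_PATH = None
    BATCH_SIZE = 32
    NUM_WORKERS = 4
    USE_GPU = True
    LEARNING_RATE = 1e-3
    LEARNING_RATE_DECAY = 0.95
    WEIGHT_DECAY = 1e-4
    MAX_EPOCH = 10
    PRINT_FREQ = 20

    def parse(self, kwargs):
        # Override known attributes from keyword arguments, warning on typos.
        for key, value in kwargs.items():
            if not hasattr(self, key):
                warnings.warn('unknown config option: %s' % key)
            setattr(self, key, value)


conf = DefaultConfig()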
###############################################################################
print(f"{gct()} : Start training")

best_ms = None
best_f = None
start_epoch = args.start_epoch + 1
end = cfg.TRAIN.EPOCH_NUM

for epoch in range(start_epoch, end):
    epoch_start_time = time.time()
    train()
    checkpoint, val_ms = evaluate(val_data)

    # Save the model if the match score is the best we've seen so far.
    # (`is None` rather than `not best_ms`, so a legitimate score of 0.0
    # is not mistaken for "no best score yet".)
    if best_ms is None or val_ms >= best_ms:
        state = {
            "epoch": epoch,
            "state_dict": model.state_dict(),
            "det_optim": det_optim.state_dict(),
            "des_optim": des_optim.state_dict(),
        }
        filename = f"{args.save}/model/e{epoch:03d}_{checkpoint}.pth.tar"
        torch.save(state, filename)
        best_ms = val_ms
        best_f = filename

    print("-" * 96)
    print("| end of epoch {:3d} | time: {:5.02f}s | val ms {:5.03f} | best ms {:5.03f} | "
          .format(epoch, (time.time() - epoch_start_time), val_ms, best_ms))
    print("-" * 96)

# Load the best saved model.
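# A sketch of the step the trailing comment announces: restore the best
# checkpoint recorded in `best_f`. The dictionary keys mirror the `state`
# saved in the loop above; guarding on `best_f` is an added assumption in
# case no epoch produced a score.
if best_f is not None:
    state = torch.load(best_f)
    model.load_state_dict(state["state_dict"])
    det_optim.load_state_dict(state["det_optim"])
    des_optim.load_state_dict(state["des_optim"])
    print(f"{gct()} : Loaded best model from {best_f} (epoch {state['epoch']})")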
import os
import time

import torch
import torch.nn.functional as F
from torch import optim
from torch.utils.data import DataLoader
from torchvision import transforms
from torchvision.datasets import MNIST


def main():
    train_dataset = MNIST(root='./data', train=True, download=True,
                          transform=transforms.ToTensor())
    test_dataset = MNIST(root='./data', train=False, download=True,
                         transform=transforms.ToTensor())
    train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True,
                              num_workers=2)
    test_loader = DataLoader(test_dataset, batch_size=BATCH_SIZE, shuffle=False,
                             num_workers=2)

    net = Network(1, 64, 5, 10)
    if USE_CUDA:
        net = net.cuda()
    opt = optim.SGD(net.parameters(), lr=LEARNING_RATE, weight_decay=WEIGHT_DECAY,
                    momentum=.9, nesterov=True)

    if not os.path.exists('checkpoint'):
        os.mkdir('checkpoint')

    for epoch in range(1, EPOCHS + 1):
        print('[Epoch %d]' % epoch)

        net.train()
        train_loss = 0
        train_correct, train_total = 0, 0
        start_point = time.time()

        for inputs, labels in train_loader:
            if USE_CUDA:
                inputs, labels = inputs.cuda(), labels.cuda()

            opt.zero_grad()
            # F.cross_entropy expects raw logits and applies log-softmax
            # internally; the original passed log_softmax output into it,
            # applying the softmax twice.
            outputs = net(inputs)
            loss = F.cross_entropy(outputs, labels)
            loss.backward()
            opt.step()

            train_loss += loss.item()
            train_correct += (outputs.argmax(dim=1) == labels).sum().item()
            train_total += labels.size(0)

        print('train-acc : %.4f%% train-loss : %.5f' %
              (100 * train_correct / train_total, train_loss / len(train_loader)))
        print('elapsed time: %ds' % (time.time() - start_point))

        net.eval()
        test_loss = 0
        test_correct, test_total = 0, 0
        with torch.no_grad():
            for inputs, labels in test_loader:
                if USE_CUDA:
                    inputs, labels = inputs.cuda(), labels.cuda()
                # Raw logits here too; argmax is unchanged by softmax, so
                # accuracy is identical to the original.
                outputs = net(inputs)
                test_loss += F.cross_entropy(outputs, labels).item()
                test_correct += (outputs.argmax(dim=1) == labels).sum().item()
                test_total += labels.size(0)

        print('test-acc : %.4f%% test-loss : %.5f' %
              (100 * test_correct / test_total, test_loss / len(test_loader)))
        torch.save(net.state_dict(), './checkpoint/checkpoint-%04d.bin' % epoch)
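# main() relies on module-level hyperparameters and a `Network` class defined
# elsewhere. The values below are illustrative assumptions, and the CNN is
# only a plausible reading of Network(1, 64, 5, 10) as
# (in_channels, hidden_channels, kernel_size, num_classes), not the original
# architecture.
import torch.nn as nn

BATCH_SIZE = 128
LEARNING_RATE = 0.01
WEIGHT_DECAY = 5e-4
EPOCHS = 10
USE_CUDA = torch.cuda.is_available()


class Network(nn.Module):
    def __init__(self, in_channels, hidden_channels, kernel_size, num_classes):
        super().__init__()
        self.features = nn.Sequential(
            nn.Conv2d(in_channels, hidden_channels, kernel_size,
                      padding=kernel_size // 2),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(2),  # 28x28 -> 14x14 on MNIST
            nn.Conv2d(hidden_channels, hidden_channels, kernel_size,
                      padding=kernel_size // 2),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(2),  # 14x14 -> 7x7
        )
        self.classifier = nn.Linear(hidden_channels * 7 * 7, num_classes)

    def forward(self, x):
        x = self.features(x)
        return self.classifier(x.flatten(1))


if __name__ == '__main__':
    main()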