def main():
    """Entry point: build PoseNet, then run the train/valid loop with checkpointing.

    Relies on names defined elsewhere in this project/module: ``options``,
    ``PoseNet``, ``dataload``, ``save_options``, ``Model_Checkpoints``,
    ``train_func`` and the TensorBoard ``writer``.
    """
    # cuDNN autotuner: fastest conv algorithms for fixed-size inputs.
    cudnn.benchmark = True
    cudnn.enabled = True

    opts = options()
    continue_exp = opts.continue_exp

    model = PoseNet(nstack=opts.nstack, inp_dim=opts.inp_dim,
                    oup_dim=opts.oup_dim, masks_flag=opts.masks_flag)
    print(">>> total params: {:.2f}M".format(
        sum(p.numel() for p in model.parameters()) / 1000000.0))

    optimizer = torch.optim.Adam(model.parameters(), lr=opts.lr)

    # Factory producing the train / valid batch generators.
    data_load_func = dataload.init(opts)
    save_options(opts, os.path.join('log/train_option/' + opts.exp),
                 model.__str__(), optimizer.__str__())

    begin_epoch = 0
    total_epochs = opts.total_epochs

    # Optionally resume from the last saved checkpoint of this experiment.
    if continue_exp:
        begin_epoch = Model_Checkpoints(opts).load_checkpoints(model, optimizer)
        print('Start training # epoch{}'.format(begin_epoch))

    for epoch in range(begin_epoch, total_epochs):
        print('-------------Training Epoch {}-------------'.format(epoch))

        # Training then validation within each epoch.
        for phase in ['train', 'valid']:
            num_step = opts.train_iters if phase == 'train' else opts.valid_iters
            generator = data_load_func(phase)
            print('start', phase)

            show_range = tqdm.tqdm(range(num_step), total=num_step, ascii=True)
            for i in show_range:
                datas = next(generator)
                loss = train_func(opts, model, optimizer, phase, **datas)
                # Global step for TensorBoard (per-phase step counter).
                niter = epoch * num_step + i
                # BUG FIX: loss.data[0] was removed in PyTorch >= 0.5 for
                # 0-dim tensors; loss.item() extracts the Python scalar.
                if phase == 'train' and i % 20 == 0:
                    writer.add_scalar('{}/Loss'.format(phase), loss.item(), niter)
                if phase == 'valid':
                    # BUG FIX: previously logged against the stale train-phase
                    # `niter`; use the current validation step instead.
                    writer.add_scalar('{}/Loss'.format(phase), loss.item(), niter)

        # Rolling "latest" checkpoint every epoch ...
        Model_Checkpoints(opts).save_checkpoints({
            'epoch': epoch + 1,
            'state_dict': model.state_dict(),
            'optimizer': optimizer.state_dict(),
        })
        # ... plus a permanent snapshot every 50 epochs (skip epoch 0).
        if epoch % 50 == 0 and epoch != 0:
            Model_Checkpoints(opts).save_checkpoints(
                {
                    'epoch': epoch + 1,
                    'state_dict': model.state_dict(),
                    'optimizer': optimizer.state_dict(),
                },
                filename='{}_checkpoint.pth.tar'.format(epoch))
def main(args):
    """Train PoseNet with OneFlow on OFRecord data, validating after each epoch.

    Args:
        args: parsed CLI namespace providing ``ofrecord_path``,
            ``train_batch_size``, ``val_batch_size``, ``load_checkpoint``,
            ``learning_rate``, ``mom``, ``epochs`` and ``save_checkpoint_path``.

    Side effects: saves a checkpoint per epoch under ``args.save_checkpoint_path``
    and writes sampled training losses to ``of_losses.txt``.
    """
    flow.enable_eager_execution()

    train_data_loader = OFRecordDataLoader(
        ofrecord_root=args.ofrecord_path,
        mode="train",
        # NOTE(Liang Depeng): needs to explictly set the dataset size
        dataset_size=7459,
        batch_size=args.train_batch_size,
    )
    val_data_loader = OFRecordDataLoader(
        ofrecord_root=args.ofrecord_path,
        mode="val",
        dataset_size=1990,
        batch_size=args.val_batch_size,
    )

    # oneflow init (optionally resuming from a saved checkpoint)
    start_t = time.time()
    posenet_module = PoseNet()
    if args.load_checkpoint != "":
        posenet_module.load_state_dict(flow.load(args.load_checkpoint))
    end_t = time.time()
    print("init time : {}".format(end_t - start_t))

    of_cross_entropy = flow.nn.CrossEntropyLoss()
    posenet_module.to("cuda")
    of_cross_entropy.to("cuda")
    of_sgd = flow.optim.SGD(
        posenet_module.parameters(), lr=args.learning_rate, momentum=args.mom
    )

    of_losses = []
    all_samples = len(val_data_loader) * args.val_batch_size
    print_interval = 100

    for epoch in range(args.epochs):
        posenet_module.train()
        for b in range(len(train_data_loader)):
            image, label = train_data_loader.get_batch()
            # oneflow train step
            start_t = time.time()
            image = image.to("cuda")
            label = label.to("cuda")
            logits = posenet_module(image)
            loss = of_cross_entropy(logits, label)
            loss.backward()
            of_sgd.step()
            of_sgd.zero_grad()
            end_t = time.time()
            if b % print_interval == 0:
                l = loss.numpy()
                of_losses.append(l)
                print(
                    "epoch {} train iter {} oneflow loss {}, train time : {}".format(
                        epoch, b, l, end_t - start_t
                    )
                )

        print("epoch %d train done, start validation" % epoch)
        posenet_module.eval()
        correct_of = 0.0
        for b in range(len(val_data_loader)):
            image, label = val_data_loader.get_batch()
            start_t = time.time()
            image = image.to("cuda")
            with flow.no_grad():
                logits = posenet_module(image)
                predictions = logits.softmax()
            of_predictions = predictions.numpy()
            # Top-1 prediction per sample vs. ground-truth labels.
            clsidxs = np.argmax(of_predictions, axis=1)
            label_nd = label.numpy()
            for i in range(args.val_batch_size):
                if clsidxs[i] == label_nd[i]:
                    correct_of += 1
            end_t = time.time()

        print("epoch %d, oneflow top1 val acc: %f" % (epoch, correct_of / all_samples))
        flow.save(
            posenet_module.state_dict(),
            os.path.join(
                args.save_checkpoint_path,
                "epoch_%d_val_acc_%f" % (epoch, correct_of / all_samples),
            ),
        )

    # BUG FIX: use a context manager so the loss-log file is closed even if a
    # write fails (original opened/closed the handle manually).
    with open("of_losses.txt", "w") as loss_file:
        for o in of_losses:
            loss_file.write("%f\n" % o)