# Assumed imports for this excerpt: the stdlib/PyTorch ones are standard,
# while the repo-local module paths (losses, networks, datasets, common
# helpers such as setup_environment, get_model_resgcn, get_trainer, train,
# evaluate, save_model) follow the SupContrast / GaitGraph layouts and may
# need adjusting to your project tree.
import os
import time

import apex
import torch
import torch.backends.cudnn as cudnn
import torchvision.transforms as transforms
from ray import tune
from torch.utils.tensorboard import SummaryWriter

from common import *  # setup_environment, get_model_resgcn, get_trainer, ...
from datasets import dataset_factory
from datasets.augmentation import *  # MirrorPoses, FlipSequence, PointNoise, ...
from datasets.graph import Graph
from losses import SupConLoss
from networks.resnet_big import SupConResNet


def set_model(opt):
    model = SupConResNet(name=opt.model)
    criterion = SupConLoss(temperature=opt.temp)

    # enable synchronized Batch Normalization (requires NVIDIA apex)
    if opt.syncBN:
        model = apex.parallel.convert_syncbn_model(model)

    if torch.cuda.is_available():
        # parallelize only the encoder across GPUs; the projection head
        # stays on the default device
        if torch.cuda.device_count() > 1:
            model.encoder = torch.nn.DataParallel(model.encoder)
        model = model.cuda()
        criterion = criterion.cuda()
        cudnn.benchmark = True

    return model, criterion
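
# Illustrative sketch (not part of the original script): how the two
# augmented views produced by a two-crop/two-noise transform are typically
# fed to SupConLoss, which expects features shaped [bsz, n_views, feat_dim].
# `view1`, `view2`, and `labels` are hypothetical placeholder tensors.
def supcon_step_sketch(model, criterion, view1, view2, labels):
    bsz = labels.shape[0]
    images = torch.cat([view1, view2], dim=0)   # both views in one batch: [2 * bsz, ...]
    features = model(images)                    # [2 * bsz, feat_dim]
    f1, f2 = torch.split(features, [bsz, bsz], dim=0)
    features = torch.cat([f1.unsqueeze(1), f2.unsqueeze(1)], dim=1)  # [bsz, 2, feat_dim]
    return criterion(features, labels)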
def main(opt):
    opt = setup_environment(opt)
    graph = Graph("coco")

    # Dataset: training transform with random augmentations; TwoNoiseTransform
    # applies it twice per sample to produce the two views SupConLoss needs.
    transform = transforms.Compose([
        MirrorPoses(opt.mirror_probability),
        FlipSequence(opt.flip_probability),
        RandomSelectSequence(opt.sequence_length),
        ShuffleSequence(opt.shuffle),
        PointNoise(std=opt.point_noise_std),
        JointNoise(std=opt.joint_noise_std),
        MultiInput(graph.connect_joint, opt.use_multi_branch),
        ToTensor()
    ])

    dataset_class = dataset_factory(opt.dataset)
    dataset = dataset_class(
        opt.train_data_path,
        train=True,
        sequence_length=opt.sequence_length,
        transform=TwoNoiseTransform(transform),
    )
    # validation uses a deterministic center crop and no noise augmentation
    dataset_valid = dataset_class(
        opt.valid_data_path,
        sequence_length=opt.sequence_length,
        transform=transforms.Compose([
            SelectSequenceCenter(opt.sequence_length),
            MultiInput(graph.connect_joint, opt.use_multi_branch),
            ToTensor()
        ]),
    )

    train_loader = torch.utils.data.DataLoader(
        dataset,
        batch_size=opt.batch_size,
        num_workers=opt.num_workers,
        pin_memory=True,
        shuffle=True,
    )
    val_loader = torch.utils.data.DataLoader(
        dataset_valid,
        batch_size=opt.batch_size_validation,
        num_workers=opt.num_workers,
        pin_memory=True,
    )

    # Model & criterion
    model, model_args = get_model_resgcn(graph, opt)
    criterion = SupConLoss(temperature=opt.temp)

    print("# parameters: ", count_parameters(model))

    if torch.cuda.device_count() > 1:
        model = torch.nn.DataParallel(model, opt.gpus)

    if opt.cuda:
        model.cuda()
        criterion.cuda()

    # Trainer
    optimizer, scheduler, scaler = get_trainer(model, opt, len(train_loader))

    # Load checkpoint or weights
    load_checkpoint(model, optimizer, scheduler, scaler, opt)

    # Tensorboard
    writer = SummaryWriter(log_dir=opt.tb_path)

    sample_input = torch.zeros(
        opt.batch_size, model_args["num_input"], model_args["num_channel"],
        opt.sequence_length, graph.num_node
    ).cuda()
    writer.add_graph(model, input_to_model=sample_input)

    best_acc = 0
    loss = 0
    for epoch in range(opt.start_epoch, opt.epochs + 1):
        # train for one epoch
        time1 = time.time()
        loss = train(train_loader, model, criterion, optimizer, scheduler,
                     scaler, epoch, opt)
        time2 = time.time()
        print(f"epoch {epoch}, total time {time2 - time1:.2f}")

        # tensorboard logger
        writer.add_scalar("loss/train", loss, epoch)
        writer.add_scalar("learning_rate", optimizer.param_groups[0]["lr"], epoch)

        # evaluation
        result, accuracy_avg, sub_accuracies, dataframe = evaluate(
            val_loader, model, opt.evaluation_fn, use_flip=True)
        writer.add_text("accuracy/validation", dataframe.to_markdown(), epoch)
        writer.add_scalar("accuracy/validation", accuracy_avg, epoch)
        for key, sub_accuracy in sub_accuracies.items():
            writer.add_scalar(f"accuracy/validation/{key}", sub_accuracy, epoch)

        print(f"epoch {epoch}, avg accuracy {accuracy_avg:.4f}")
        is_best = accuracy_avg > best_acc
        if is_best:
            best_acc = accuracy_avg

        if opt.tune:
            tune.report(accuracy=accuracy_avg)

        if epoch % opt.save_interval == 0 or (
                is_best and epoch > opt.save_best_start * opt.epochs):
            save_file = os.path.join(
                opt.save_folder,
                f"ckpt_epoch_{'best' if is_best else epoch}.pth")
            # record the current epoch (not opt.epochs) so that resuming
            # from this checkpoint restarts at the right place
            save_model(model, optimizer, scheduler, scaler, opt, epoch, save_file)

    # save the last model
    save_file = os.path.join(opt.save_folder, "last.pth")
    save_model(model, optimizer, scheduler, scaler, opt, opt.epochs, save_file)

    log_hyperparameter(writer, opt, best_acc, loss)

    print(f"best accuracy: {best_acc * 100:.2f}")
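
# Minimal sketch of the two-view wrapper used above, assuming it mirrors
# SupContrast's TwoCropTransform: the same stochastic transform is applied
# twice, yielding two independently augmented views of one pose sequence for
# the contrastive loss. Shown only to make the excerpt self-explanatory; the
# repo ships its own TwoNoiseTransform implementation.
class TwoNoiseTransformSketch:
    def __init__(self, transform):
        self.transform = transform

    def __call__(self, x):
        # two passes through the random augmentations -> two distinct views
        return [self.transform(x), self.transform(x)]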