# Per-class validation accuracy (labels holds argmax predictions, y ground truth).
y = np.array(y)  # ensure array semantics for boolean masking
for class_label in np.unique(y):
    idx = y == class_label
    acc = (labels[idx] == y[idx]).astype(float).mean() * 100  # np.float is removed in NumPy >= 1.20
    print('accuracy for class', class_label, 'is', acc)

# Overall accuracy across all classes.
acc = (labels == y).mean() * 100

# Collapse the multi-class output to a binary cover-vs-stego score for the
# competition metric: for predicted stego images, sum the stego-class
# probabilities; for predicted cover images, use 1 - P(cover).
new_preds = np.zeros((len(preds),))
new_preds[labels != 0] = preds[labels != 0, 1:].sum(1)
new_preds[labels == 0] = 1 - preds[labels == 0, 0]

# Binarize the ground truth: 0 = cover, 1 = any stego algorithm.
y[y != 0] = 1
auc_score = alaska_weighted_auc(y, new_preds)

print(f'Val Loss: {epoch_loss:.3}, Weighted AUC: {auc_score:.3}, Acc: {acc:.3}')
torch.save(model.state_dict(),
           f"epoch_{epoch}_val_loss_{epoch_loss:.3}_auc_{auc_score:.3}_rgb.pth")

# Build the test set file list and loader.
test_dir = os.path.join(PATH, 'Test')
test_ids = [os.path.join(test_dir, fname) for fname in os.listdir(test_dir)]

test_dataset = Alaska2Dataset(test_ids, None,
                              augmentations=AUGMENTATIONS_TEST,
                              test=True,
                              color_mode=color_mode)
batch_size = 16
num_workers = 0
test_loader = torch.utils.data.DataLoader(test_dataset,
                                          batch_size=batch_size,
                                          num_workers=num_workers,
                                          shuffle=False)
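# The validation code above calls alaska_weighted_auc, which is not defined in
# this snippet. A minimal sketch, assuming the standard public definition of
# the ALASKA2 weighted ROC AUC (the TPR band [0, 0.4] is weighted twice as
# heavily as the band [0.4, 1.0]); the clipped-band integration below is one
# way to realize that definition, not necessarily the author's exact code.
import numpy as np
from sklearn import metrics

def alaska_weighted_auc(y_true, y_pred):
    tpr_thresholds = [0.0, 0.4, 1.0]
    weights = [2, 1]
    fpr, tpr, _ = metrics.roc_curve(y_true, y_pred, pos_label=1)
    # Normalize by the total weighted height of the TPR bands.
    normalization = np.dot(np.diff(tpr_thresholds), weights)
    score = 0.0
    for idx, weight in enumerate(weights):
        y_min, y_max = tpr_thresholds[idx], tpr_thresholds[idx + 1]
        # Clip the ROC curve to the current TPR band and integrate over FPR.
        band = np.clip(tpr, y_min, y_max) - y_min
        score += weight * metrics.auc(fpr, band)
    return score / normalization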
def main():
    logger.info(args)
    assert os.path.isdir(CONFIGS["DATA"]["DIR"])

    # Optional deterministic seeding.
    if CONFIGS['TRAIN']['SEED'] is not None:
        random.seed(CONFIGS['TRAIN']['SEED'])
        torch.manual_seed(CONFIGS['TRAIN']['SEED'])
        cudnn.deterministic = True

    model = Net(numAngle=CONFIGS["MODEL"]["NUMANGLE"],
                numRho=CONFIGS["MODEL"]["NUMRHO"],
                backbone=CONFIGS["MODEL"]["BACKBONE"])

    if CONFIGS["TRAIN"]["DATA_PARALLEL"]:
        logger.info("Model Data Parallel")
        model = nn.DataParallel(model).cuda()
    else:
        model = model.cuda(device=CONFIGS["TRAIN"]["GPU_ID"])

    # optimizer
    optimizer = torch.optim.Adam(
        model.parameters(),
        lr=CONFIGS["OPTIMIZER"]["LR"],
        weight_decay=CONFIGS["OPTIMIZER"]["WEIGHT_DECAY"])

    # learning rate scheduler
    scheduler = lr_scheduler.MultiStepLR(
        optimizer,
        milestones=CONFIGS["OPTIMIZER"]["STEPS"],
        gamma=CONFIGS["OPTIMIZER"]["GAMMA"])

    best_acc1 = 0

    # Resume from a checkpoint if one was passed on the command line.
    if args.resume:
        if isfile(args.resume):
            logger.info("=> loading checkpoint '{}'".format(args.resume))
            checkpoint = torch.load(args.resume)
            args.start_epoch = checkpoint['epoch']
            best_acc1 = checkpoint['best_acc1']
            model.load_state_dict(checkpoint['state_dict'])
            # optimizer.load_state_dict(checkpoint['optimizer'])
            logger.info("=> loaded checkpoint '{}' (epoch {})".format(
                args.resume, checkpoint['epoch']))
        else:
            logger.info("=> no checkpoint found at '{}'".format(args.resume))

    # dataloaders
    train_loader = get_loader(CONFIGS["DATA"]["DIR"], CONFIGS["DATA"]["LABEL_FILE"],
                              batch_size=CONFIGS["DATA"]["BATCH_SIZE"],
                              num_thread=CONFIGS["DATA"]["WORKERS"], split='train')
    val_loader = get_loader(CONFIGS["DATA"]["VAL_DIR"], CONFIGS["DATA"]["VAL_LABEL_FILE"],
                            batch_size=1, num_thread=CONFIGS["DATA"]["WORKERS"], split='val')
    logger.info("Data loading done.")

    # Tensorboard summary
    writer = SummaryWriter(log_dir=os.path.join(CONFIGS["MISC"]["TMP"]))

    start_epoch = 0
    best_acc = best_acc1
    is_best = False
    start_time = time.time()

    if CONFIGS["TRAIN"]["RESUME"] is not None:
        raise NotImplementedError

    # Evaluation-only mode: run one validation pass and exit.
    if CONFIGS["TRAIN"]["TEST"]:
        validate(val_loader, model, 0, writer, args)
        return

    logger.info("Start training.")
    for epoch in range(start_epoch, CONFIGS["TRAIN"]["EPOCHS"]):
        train(train_loader, model, optimizer, epoch, writer, args)
        acc = validate(val_loader, model, epoch, writer, args)
        scheduler.step()

        if best_acc < acc:
            is_best = True
            best_acc = acc
        else:
            is_best = False

        save_checkpoint(
            {
                'epoch': epoch + 1,
                'state_dict': model.state_dict(),
                'best_acc1': best_acc,
                'optimizer': optimizer.state_dict()
            }, is_best, path=CONFIGS["MISC"]["TMP"])

        # Estimate elapsed and remaining wall-clock time.
        t = time.time() - start_time
        elapsed = DayHourMinute(t)
        t /= (epoch + 1) - start_epoch  # seconds per epoch
        t = (CONFIGS["TRAIN"]["EPOCHS"] - epoch - 1) * t
        remaining = DayHourMinute(t)

        logger.info(
            "Epoch {0}/{1} finished, auxiliaries saved to {2}.\t"
            "Elapsed {elapsed.days:d} days {elapsed.hours:d} hours {elapsed.minutes:d} minutes.\t"
            "Remaining {remaining.days:d} days {remaining.hours:d} hours {remaining.minutes:d} minutes."
            .format(epoch, CONFIGS["TRAIN"]["EPOCHS"], CONFIGS["MISC"]["TMP"],
                    elapsed=elapsed, remaining=remaining))

    logger.info("Optimization done, ALL results saved to %s." % CONFIGS["MISC"]["TMP"])
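# main() relies on two small helpers that are not shown here. Hypothetical
# sketches, assuming only the behaviour implied by the call sites above:
# DayHourMinute splits a duration in seconds into the .days/.hours/.minutes
# fields consumed by the log format strings, and save_checkpoint writes the
# state dict to disk, copying it to a "best" file when is_best is set. The
# names match the calls above, but the bodies are assumptions.
import os
import shutil
import torch

class DayHourMinute:
    def __init__(self, seconds):
        self.days = int(seconds // 86400)
        self.hours = int((seconds % 86400) // 3600)
        self.minutes = int((seconds % 3600) // 60)

def save_checkpoint(state, is_best, path, filename='checkpoint.pth'):
    checkpoint_path = os.path.join(path, filename)
    torch.save(state, checkpoint_path)
    if is_best:
        shutil.copyfile(checkpoint_path, os.path.join(path, 'model_best.pth'))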