def predict(model_name, model_class, weight_pth, image_size, normalize):
    """Run test-time-augmented (TTA) inference for one model and save results.

    Loads `weight_pth` into a fresh model built via `get_model(model_class)`,
    assembles the TTA transform list (plain + horizontal-flip, plus five-crop
    variants of each), then runs prediction over both the 'test' and 'val'
    splits, writing `{model_name}_test_prediction.pth` and
    `{model_name}_val_prediction.pth`. Each saved file holds
    {'lx': labels, 'px': predictions} moved to CPU.

    Args:
        model_name: prefix for the saved prediction files.
        model_class: forwarded to `get_model` to construct the network.
        weight_pth: path to the state-dict checkpoint to load.
        image_size: target crop size used by every transform.
        normalize: normalization transform applied after ToTensor.
    """
    print(f'[+] predict {model_name}')
    model = get_model(model_class)
    model.load_state_dict(torch.load(weight_pth))
    model.eval()

    tta_preprocess = [
        preprocess(normalize, image_size),
        preprocess_hflip(normalize, image_size),
    ]
    # Five-crop variants: resize slightly larger (+20 px), then take the five
    # crops — once as-is and once with a horizontal flip before ToTensor.
    tta_preprocess += make_transforms(
        [transforms.Resize((image_size + 20, image_size + 20))],
        [transforms.ToTensor(), normalize],
        five_crops(image_size))
    tta_preprocess += make_transforms(
        [transforms.Resize((image_size + 20, image_size + 20))],
        [HorizontalFlip(), transforms.ToTensor(), normalize],
        five_crops(image_size))
    print(f'[+] tta size: {len(tta_preprocess)}')

    # The original inlined two byte-identical prediction passes for the
    # 'test' and 'val' splits; they now share one helper.
    for split in ('test', 'val'):
        _predict_split(model, tta_preprocess, split,
                       f'{model_name}_{split}_prediction.pth')


def _predict_split(model, tta_preprocess, split, out_path):
    """Predict one FurnitureDataset split under every TTA transform, save to out_path."""
    data_loaders = []
    for transform in tta_preprocess:
        dataset = FurnitureDataset(split, transform=transform)
        data_loaders.append(DataLoader(dataset=dataset, num_workers=1,
                                       batch_size=BATCH_SIZE, shuffle=False))
    lx, px = utils.predict_tta(model, data_loaders)
    data = {
        'lx': lx.cpu(),
        'px': px.cpu(),
    }
    torch.save(data, out_path)
def train():
    """Two-stage fine-tuning loop with checkpoint-restore LR decay.

    Epoch 0 trains only the freshly added layers (`model.fresh_params()`) at
    lr=1e-3; from epoch 1 the whole network is trained with Adam
    (lr=3e-5, weight_decay=1e-4). When validation log-loss fails to improve
    for 2 consecutive epochs, the best checkpoint is reloaded and the
    learning rate is divided by 10. Best weights are saved to
    'densenet201_15755.pth'.

    NOTE(review): uses pre-0.4 PyTorch idioms (`Variable`, `loss.data[0]`) —
    assumes torch < 0.4; verify against the project's pinned version.
    """
    train_dataset = FurnitureDataset('train', transform=preprocess_with_augmentation(normalize_torch, IMAGE_SIZE))
    val_dataset = FurnitureDataset('val', transform=preprocess(normalize_torch, IMAGE_SIZE))
    training_data_loader = DataLoader(dataset=train_dataset, num_workers=8,
                                      batch_size=BATCH_SIZE, shuffle=True)
    validation_data_loader = DataLoader(dataset=val_dataset, num_workers=1,
                                        batch_size=BATCH_SIZE, shuffle=False)
    model = get_model()
    criterion = nn.CrossEntropyLoss().cuda()

    nb_learnable_params = sum(p.numel() for p in model.fresh_params())
    print(f'[+] nb learnable params {nb_learnable_params}')

    # Baseline validation loss of the untrained head; used as the initial
    # "best so far" so the first improvement triggers a checkpoint save.
    lx, px = utils.predict(model, validation_data_loader)
    min_loss = criterion(Variable(px), Variable(lx)).data[0]

    lr = 0
    patience = 0
    for epoch in range(20):
        print(f'epoch {epoch}')
        # Ordering of these three checks matters: the epoch==1 assignment
        # runs before the patience check, and the epoch==0 branch last so
        # epoch 0 always uses the fresh-params optimizer at lr=1e-3.
        if epoch == 1:
            lr = 0.00003
            print(f'[+] set lr={lr}')
        if patience == 2:
            # Two epochs without improvement: restore the best checkpoint
            # and decay the learning rate by 10x.
            patience = 0
            model.load_state_dict(torch.load('densenet201_15755.pth'))
            lr = lr / 10
            print(f'[+] set lr={lr}')
        if epoch == 0:
            # Warm-up: optimize only the newly added ("fresh") layers.
            lr = 0.001
            print(f'[+] set lr={lr}')
            optimizer = torch.optim.Adam(model.fresh_params(), lr=lr)
        else:
            optimizer = torch.optim.Adam(model.parameters(), lr=lr, weight_decay=0.0001)

        running_loss = RunningMean()
        running_score = RunningMean()

        model.train()
        pbar = tqdm(training_data_loader, total=len(training_data_loader))
        for inputs, labels in pbar:
            batch_size = inputs.size(0)

            inputs = Variable(inputs)
            labels = Variable(labels)
            if use_gpu:
                inputs = inputs.cuda()
                labels = labels.cuda()

            optimizer.zero_grad()
            outputs = model(inputs)
            _, preds = torch.max(outputs.data, dim=1)

            loss = criterion(outputs, labels)
            running_loss.update(loss.data[0], 1)
            # NOTE(review): `preds != labels.data` counts *mis*classified
            # samples, so running_score tracks the error rate, not accuracy.
            running_score.update(torch.sum(preds != labels.data), batch_size)

            loss.backward()
            optimizer.step()

            pbar.set_description(f'{running_loss.value:.5f} {running_score.value:.3f}')
        print(f'[+] epoch {epoch} {running_loss.value:.5f} {running_score.value:.3f}')

        # Full-validation pass; checkpoint when log-loss improves.
        lx, px = utils.predict(model, validation_data_loader)
        log_loss = criterion(Variable(px), Variable(lx))
        log_loss = log_loss.data[0]
        _, preds = torch.max(px, dim=1)
        # NOTE(review): despite the name, this is the error rate
        # (mean of preds != lx), not accuracy.
        accuracy = torch.mean((preds != lx).float())
        print(f'[+] val {log_loss:.5f} {accuracy:.3f}')

        if log_loss < min_loss:
            torch.save(model.state_dict(), 'densenet201_15755.pth')
            print(f'[+] val score improved from {min_loss:.5f} to {log_loss:.5f}. Saved!')
            min_loss = log_loss
            patience = 0
        else:
            patience += 1
import data, pr, info_fns, G_builder, misc ########## LOCAL PATHS ############ output_path = 'C:/Users/Crbn/Documents/MPRI M2/ReSys/project/output/' hema_file = 'C:/Users/Crbn/Documents/MPRI M2/ReSys/project/data/wholecells_binary.csv' ########## MAIN PARAMETERS ########### dataset='hema' #can also try others in data.py, such as '3NAND_AND_2OR' cutoff=None G=None thresh_mult = 1 ############# MAIN ################ dataa, gene_names = data.gen_data(set=dataset, include_stages=False, cutoff=cutoff, hema_file=hema_file) dataa,gene_names = misc.preprocess(dataa, gene_names) G = G_builder.build(dataa, gene_names, G=G, thresh_mult=thresh_mult) G = misc.postprocess(G) G = misc.assign_stages(G) #print(G.nodes[edge[0]]['gene'], '->', G.nodes[edge[1]]['gene']) misc.drawG(G,output_path)