import os
import datetime
import logging
import pickle
from random import shuffle

import torch
import torch.nn as nn
import torch.optim as optim
import torch.utils.data
from torch.autograd import Variable
import torchvision
import torchvision.transforms as transforms

# Project-local components. The import paths below are assumptions based on
# the NeuSomatic repository layout; SubsetNoneSampler,
# make_weights_for_balanced_classes, and test are assumed to be defined
# alongside this function.
from network import NeuSomaticNet
from dataloader import NeuSomaticDataset
from merge_tsvs import merge_tsvs

def train_neusomatic(candidates_tsv, validation_candidates_tsv, out_dir,
                     checkpoint, num_threads, batch_size, max_epochs,
                     learning_rate, lr_drop_epochs, lr_drop_ratio, momentum,
                     boost_none, none_count_scale, max_load_candidates,
                     coverage_thr, save_freq, ensemble,
                     merged_candidates_per_tsv, merged_max_num_tsvs,
                     overwrite_merged_tsvs, train_split_len, use_cuda):
    logger = logging.getLogger(train_neusomatic.__name__)

    logger.info("----------------Train NeuSomatic Network-------------------")
    logger.info("PyTorch Version: {}".format(torch.__version__))
    logger.info("Torchvision Version: {}".format(torchvision.__version__))

    if not os.path.exists(out_dir):
        os.mkdir(out_dir)

    if not use_cuda:
        torch.set_num_threads(num_threads)

    data_transform = transforms.Compose(
        [transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])
    num_channels = 119 if ensemble else 26
    net = NeuSomaticNet(num_channels)
    if use_cuda:
        logger.info("GPU training!")
        net.cuda()
    else:
        logger.info("CPU training!")

    if torch.cuda.device_count() > 1:
        logger.info("We use {} GPUs!".format(torch.cuda.device_count()))
        net = nn.DataParallel(net)

    if not os.path.exists("{}/models/".format(out_dir)):
        os.mkdir("{}/models/".format(out_dir))

    if checkpoint:
        logger.info(
            "Load pretrained model from checkpoint {}".format(checkpoint))
        pretrained_dict = torch.load(
            checkpoint, map_location=lambda storage, loc: storage)
        pretrained_state_dict = pretrained_dict["state_dict"]
        tag = pretrained_dict["tag"]
        sofar_epochs = pretrained_dict["epoch"]
        logger.info(
            "sofar_epochs from pretrained checkpoint: {}".format(sofar_epochs))
        coverage_thr = pretrained_dict["coverage_thr"]
        logger.info(
            "Override coverage_thr from pretrained checkpoint: {}".format(
                coverage_thr))
        prev_epochs = sofar_epochs + 1

        model_dict = net.state_dict()
        # 1. filter out unnecessary keys
        # pretrained_state_dict = {
        #     k: v for k, v in pretrained_state_dict.items() if k in model_dict}
        if "module." in list(
                pretrained_state_dict.keys())[0] and "module." not in list(
                model_dict.keys())[0]:
            pretrained_state_dict = {
                k.split("module.")[1]: v
                for k, v in pretrained_state_dict.items()
                if k.split("module.")[1] in model_dict}
        elif "module." not in list(
                pretrained_state_dict.keys())[0] and "module." in list(
                model_dict.keys())[0]:
            pretrained_state_dict = {
                ("module." + k): v
                for k, v in pretrained_state_dict.items()
                if ("module." + k) in model_dict}
        else:
            pretrained_state_dict = {
                k: v for k, v in pretrained_state_dict.items()
                if k in model_dict}
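        # nn.DataParallel stores parameters under a "module." prefix, so a
        # checkpoint saved from a wrapped model does not key-match a bare
        # model (and vice versa); the branches above rename the keys to
        # reconcile the two layouts.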
        # 2. overwrite entries in the existing state dict
        model_dict.update(pretrained_state_dict)
        # 3. load the new (merged) state dict
        net.load_state_dict(model_dict)
    else:
        prev_epochs = 0
        time_now = datetime.datetime.now().strftime("%y-%m-%d-%H-%M-%S")
        tag = "neusomatic_{}".format(time_now)
    logger.info("tag: {}".format(tag))

    shuffle(candidates_tsv)
    if len(candidates_tsv) > merged_max_num_tsvs:
        candidates_tsv = merge_tsvs(
            input_tsvs=candidates_tsv, out=out_dir,
            candidates_per_tsv=merged_candidates_per_tsv,
            max_num_tsvs=merged_max_num_tsvs,
            overwrite_merged_tsvs=overwrite_merged_tsvs,
            keep_none_types=True)

    # Sort the candidate tsvs by size (using their .idx companion files) and
    # greedily pack them into splits of at most ~train_split_len candidates.
    Ls = []
    for tsv in candidates_tsv:
        idx = pickle.load(open(tsv + ".idx", "rb"))
        Ls.append(len(idx) - 1)

    Ls, candidates_tsv = list(zip(
        *sorted(zip(Ls, candidates_tsv), key=lambda x: x[0], reverse=True)))

    train_split_tsvs = []
    current_L = 0
    current_split_tsvs = []
    for i, (L, tsv) in enumerate(zip(Ls, candidates_tsv)):
        current_L += L
        current_split_tsvs.append(tsv)
        if current_L >= train_split_len or (
                i == len(candidates_tsv) - 1 and current_L > 0):
            logger.info("tsvs in split {}: {}".format(
                len(train_split_tsvs), current_split_tsvs))
            train_split_tsvs.append(current_split_tsvs)
            current_L = 0
            current_split_tsvs = []
    assert sum(map(lambda x: len(x), train_split_tsvs)) == len(candidates_tsv)

    train_sets = []
    none_counts = []
    var_counts = []
    none_indices_ = []
    var_indices_ = []
    samplers = []
    for split_i, tsvs in enumerate(train_split_tsvs):
        train_set = NeuSomaticDataset(
            roots=tsvs,
            max_load_candidates=int(
                max_load_candidates * len(tsvs) / float(len(candidates_tsv))),
            transform=data_transform, is_test=False,
            num_threads=num_threads, coverage_thr=coverage_thr)
        train_sets.append(train_set)
        none_indices = train_set.get_none_indices()
        var_indices = train_set.get_var_indices()
        if none_indices:
            none_indices = list(map(
                lambda i: none_indices[i],
                torch.randperm(len(none_indices)).tolist()))
        logger.info("Non-somatic candidates in split {}: {}".format(
            split_i, len(none_indices)))
        if var_indices:
            var_indices = list(map(
                lambda i: var_indices[i],
                torch.randperm(len(var_indices)).tolist()))
        logger.info("Somatic candidates in split {}: {}".format(
            split_i, len(var_indices)))
        none_count = max(
            min(len(none_indices), len(var_indices) * none_count_scale), 1)
        logger.info(
            "Non-somatic considered in each epoch of split {}: {}".format(
                split_i, none_count))
        sampler = SubsetNoneSampler(none_indices, var_indices, none_count)
        samplers.append(sampler)
        none_counts.append(none_count)
        var_counts.append(len(var_indices))
        var_indices_.append(var_indices)
        none_indices_.append(none_indices)
    logger.info("# Total Train candidates: {}".format(
        sum(map(lambda x: len(x), train_sets))))

    if validation_candidates_tsv:
        validation_set = NeuSomaticDataset(
            roots=validation_candidates_tsv,
            max_load_candidates=max_load_candidates,
            transform=data_transform, is_test=True,
            num_threads=num_threads, coverage_thr=coverage_thr)
        validation_loader = torch.utils.data.DataLoader(
            validation_set, batch_size=batch_size,
            shuffle=True, num_workers=num_threads, pin_memory=True)
        logger.info("#Validation candidates: {}".format(len(validation_set)))

    count_class_t = [0] * 4
    count_class_l = [0] * 4
    for train_set in train_sets:
        for i in range(4):
            count_class_t[i] += train_set.count_class_t[i]
            count_class_l[i] += train_set.count_class_l[i]

    weights_type, weights_length = make_weights_for_balanced_classes(
        count_class_t, count_class_l, 4, 4, sum(none_counts))
    weights_type[2] *= boost_none
    weights_length[0] *= boost_none
    logger.info("weights_type:{}, weights_length:{}".format(
        weights_type, weights_length))

    loss_s = []
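    # The balanced class weights computed above become per-class weights for
    # the two cross-entropy criteria below, so that the abundant non-somatic
    # candidates do not dominate the rarer somatic classes during training.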
    gradients = torch.FloatTensor(weights_type)
    gradients2 = torch.FloatTensor(weights_length)
    if use_cuda:
        gradients = gradients.cuda()
        gradients2 = gradients2.cuda()
    criterion_crossentropy = nn.CrossEntropyLoss(gradients)
    criterion_crossentropy2 = nn.CrossEntropyLoss(gradients2)
    criterion_smoothl1 = nn.SmoothL1Loss()
    optimizer = optim.SGD(
        net.parameters(), lr=learning_rate, momentum=momentum)

    net.train()
    len_train_set = sum(none_counts) + sum(var_counts)
    logger.info("Number of candidates per epoch: {}".format(len_train_set))
    print_freq = max(1, int(len_train_set / float(batch_size) / 4.0))
    curr_epoch = prev_epochs
    torch.save({"state_dict": net.state_dict(),
                "tag": tag,
                "epoch": curr_epoch,
                "coverage_thr": coverage_thr},
               '{}/models/checkpoint_{}_epoch{}.pth'.format(
                   out_dir, tag, curr_epoch))

    if len(train_sets) == 1:
        train_sets[0].open_candidate_tsvs()
        train_loader = torch.utils.data.DataLoader(
            train_sets[0], batch_size=batch_size, num_workers=num_threads,
            pin_memory=True, sampler=samplers[0])

    # loop over the dataset multiple times
    n_epoch = 0
    for epoch in range(max_epochs - prev_epochs):
        n_epoch += 1
        running_loss = 0.0
        i_ = 0
        for split_i, train_set in enumerate(train_sets):
            if len(train_sets) > 1:
                train_set.open_candidate_tsvs()
                train_loader = torch.utils.data.DataLoader(
                    train_set, batch_size=batch_size,
                    num_workers=num_threads, pin_memory=True,
                    sampler=samplers[split_i])
            for i, data in enumerate(train_loader, 0):
                i_ += 1
                # get the inputs
                (inputs, labels, var_pos_s, var_len_s, _), _ = data
                # wrap them in Variable
                inputs, labels, var_pos_s, var_len_s = (
                    Variable(inputs), Variable(labels),
                    Variable(var_pos_s), Variable(var_len_s))
                if use_cuda:
                    inputs, labels, var_pos_s, var_len_s = (
                        inputs.cuda(), labels.cuda(),
                        var_pos_s.cuda(), var_len_s.cuda())

                # zero the parameter gradients
                optimizer.zero_grad()

                outputs, _ = net(inputs)
                [outputs_classification, outputs_pos, outputs_len] = outputs
                var_len_labels = Variable(
                    torch.LongTensor(var_len_s.cpu().data.numpy()))
                if use_cuda:
                    var_len_labels = var_len_labels.cuda()
                loss = (criterion_crossentropy(outputs_classification, labels)
                        + 1 * criterion_smoothl1(
                            outputs_pos.squeeze(1), var_pos_s[:, 1])
                        + 1 * criterion_crossentropy2(
                            outputs_len, var_len_labels))

                loss.backward()
                optimizer.step()
                loss_s.append(loss.data)

                running_loss += loss.data
                if i_ % print_freq == print_freq - 1:
                    logger.info(
                        'epoch: {}, iter: {:>7}, lr: {}, loss: {:.5f}'.format(
                            n_epoch + prev_epochs, len(loss_s),
                            learning_rate, running_loss / print_freq))
                    running_loss = 0.0
            if len(train_sets) > 1:
                train_set.close_candidate_tsvs()

        curr_epoch = n_epoch + prev_epochs
        if curr_epoch % save_freq == 0:
            torch.save({"state_dict": net.state_dict(),
                        "tag": tag,
                        "epoch": curr_epoch,
                        "coverage_thr": coverage_thr},
                       '{}/models/checkpoint_{}_epoch{}.pth'.format(
                           out_dir, tag, curr_epoch))
            if validation_candidates_tsv:
                test(net, curr_epoch, validation_loader, use_cuda)
        if curr_epoch % lr_drop_epochs == 0:
            learning_rate *= lr_drop_ratio
            optimizer = optim.SGD(
                net.parameters(), lr=learning_rate, momentum=momentum)
    logger.info('Finished Training')

    if len(train_sets) == 1:
        train_sets[0].close_candidate_tsvs()

    curr_epoch = n_epoch + prev_epochs
    torch.save({"state_dict": net.state_dict(),
                "tag": tag,
                "epoch": curr_epoch,
                "coverage_thr": coverage_thr},
               '{}/models/checkpoint_{}_epoch{}.pth'.format(
                   out_dir, tag, curr_epoch))
    if validation_candidates_tsv:
        test(net, curr_epoch, validation_loader, use_cuda)
    logger.info("Total Epochs: {}".format(curr_epoch))
    return '{}/models/checkpoint_{}_epoch{}.pth'.format(
        out_dir, tag, curr_epoch)
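

# A minimal invocation sketch. It is not part of the original source: the
# paths and hyperparameter values below are illustrative assumptions, and in
# the NeuSomatic repository this function is normally driven by an argparse
# command-line wrapper rather than called directly.
if __name__ == "__main__":
    logging.basicConfig(level=logging.INFO)
    final_checkpoint = train_neusomatic(
        candidates_tsv=["work_train/candidates.tsv"],  # hypothetical path
        validation_candidates_tsv=[],                  # skip validation
        out_dir="work_train",
        checkpoint=None,                               # train from scratch
        num_threads=4,
        batch_size=1000,
        max_epochs=100,
        learning_rate=0.01,
        lr_drop_epochs=20,
        lr_drop_ratio=0.1,
        momentum=0.9,
        boost_none=100,
        none_count_scale=2,
        max_load_candidates=1000000,
        coverage_thr=100,
        save_freq=10,
        ensemble=False,
        merged_candidates_per_tsv=10000000,
        merged_max_num_tsvs=10,
        overwrite_merged_tsvs=False,
        train_split_len=10000000,
        use_cuda=torch.cuda.is_available())
    print("Final checkpoint: {}".format(final_checkpoint))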