def train(train_iter, test_iter, net, feature_params, loss, device, num_epochs, file_name):
    """Fine-tune ``net`` with Ranger + StepLR, keeping the best checkpoint.

    ``net.fc`` is optimised with a learning rate 10x larger than the rest of
    the network (``feature_params``).  After every epoch the model is scored on
    ``test_iter`` with ``evaluate_accuracy``.  Whenever that score improves, the
    weights go to ``./model/<file_name>/best.pth`` and the score is appended to
    ``./result/<file_name>.txt``.  A checkpoint is also written every 10 epochs.

    Args:
        train_iter: iterable of ``(X, y)`` training batches.
        test_iter: evaluation data passed to ``evaluate_accuracy``.
        net: model exposing an ``fc`` sub-module.
        feature_params: parameters of every part of ``net`` except ``net.fc``.
        loss: callable ``loss(y_hat, y)`` returning a scalar tensor.
        device: device that the model and batches are moved to.
        num_epochs: number of training epochs.
        file_name: sub-directory / file stem for checkpoints and the result log.
    """
    net = net.to(device)
    print("training on ", device)
    best_test_acc = 0
    lr = 0.001
    optimizer = Ranger([{'params': feature_params},
                        {'params': net.fc.parameters(), 'lr': lr * 10}],
                       lr=lr, weight_decay=0.0001)
    # Divide every learning rate by 10 once per 10 epochs.
    scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=10, gamma=0.1)
    for epoch in range(1, num_epochs + 1):
        # evaluate_accuracy may leave the model in eval mode; train in train mode.
        net.train()
        train_l_sum, train_acc_sum, n, start = 0.0, 0.0, 0, time.time()
        # Reset per epoch so the printed loss is this epoch's mean batch loss
        # (it used to be divided by the cumulative batch count of all epochs).
        batch_count = 0
        for X, y in train_iter:
            X = X.to(device)
            y = y.to(device)
            y_hat = net(X)
            l = loss(y_hat, y)
            optimizer.zero_grad()
            l.backward()
            optimizer.step()
            train_l_sum += l.item()
            train_acc_sum += (y_hat.argmax(dim=1) == y).sum().item()
            n += y.shape[0]
            batch_count += 1
        # Since PyTorch 1.1 the scheduler must be stepped *after* the optimizer;
        # stepping it first made the LR decay one epoch too early.
        scheduler.step()
        test_acc = evaluate_accuracy(test_iter, net)
        print('epoch %d, loss %.5f, train_acc %.5f, val_acc %.5f, time %.1f sec'
              % (epoch, train_l_sum / max(batch_count, 1), train_acc_sum / max(n, 1),
                 test_acc, time.time() - start))
        if test_acc > best_test_acc:
            print('find best! save at model/%s/best.pth' % file_name)
            best_test_acc = test_acc
            torch.save(net.state_dict(), './model/%s/best.pth' % file_name)
            with open('./result/%s.txt' % file_name, 'a') as acc_file:
                acc_file.write('Epoch: %2d, acc: %.8f\n' % (epoch, test_acc))
        if epoch % 10 == 0:
            torch.save(net.state_dict(), './model/%s/checkpoint_%d.pth' % (file_name, epoch))
def train_discriminator_adv(dis, dis_option_encoder, state_encoder, reason_decoder,
                            dis_optimizer, epochs, max_length=25):
    """Train the discriminator to pick the correct option out of three.

    For every batch of the module-level ``train_loader_adv``:

    1. the generator (``state_encoder`` + ``reason_decoder``) greedily decodes
       a reason for each statement;
    2. that generated reason replaces the reason of one randomly chosen wrong
       option (an index != ``label``);
    3. the discriminator (``dis_option_encoder`` + ``dis``) is trained with
       cross-entropy to predict ``label`` from the three options.

    NOTE(review): the ``dis_optimizer`` argument is ignored.  A fresh Ranger
    optimizer over the discriminator parameters is built on every call, so
    optimizer state does not carry over between calls.

    Args:
        dis: discriminator head, called as ``dis(batch_size, option_hiddens)``.
        dis_option_encoder: encoder applied to each of the three option tensors.
        state_encoder: generator encoder over the statements.
        reason_decoder: generator decoder that produces the reasons.
        dis_optimizer: unused (see note above).
        epochs: number of passes over ``train_loader_adv``.
        max_length: maximum number of greedy decoding steps.
    """
    dis_optimizer = Ranger(
        itertools.chain(dis_option_encoder.parameters(), dis.parameters()),
        lr=1e-3, weight_decay=1e-4)
    loss_fn = nn.CrossEntropyLoss()
    sos_idx = word_encoder.encode("<sos>")
    eos_idx = word_encoder.encode("<eos>")
    for epoch in tqdm(range(epochs)):
        sys.stdout.flush()
        total_loss = 0
        total_acc = 0
        count = 0
        n_samples = 0
        for train_data in train_loader_adv:
            _, statement, _, label, options = train_data
            batch_size = statement.size(0)

            # For every sample, choose one of the two wrong options at random.
            falselabel = [[0, 1, 2] for _ in label]
            for i in range(len(label)):
                falselabel[i].remove(label[i])
            choice_label = [random.choice(ls) for ls in falselabel]

            # Greedy decoding.  The generated tokens are turned back into
            # strings below, so no gradient can reach the generator: skip
            # building the autograd graph.
            with torch.no_grad():
                input_lengths = [len(x) for x in statement]
                encoder_outputs, hidden = state_encoder(statement, input_lengths)
                decoder_input = torch.tensor(
                    [[sos_idx] for _ in range(batch_size)]).type(torch.cuda.LongTensor)
                encoder_outputs = encoder_outputs.permute(1, 0, 2)  # -> (T*B*H)
                decoder_outputs = [[] for _ in range(batch_size)]
                for di in range(max_length):
                    output, hidden = reason_decoder(decoder_input, hidden, encoder_outputs)
                    _, topi = output.topk(1)
                    # reshape(-1) instead of squeeze(): with a batch of one,
                    # squeeze() yields a 0-d tensor that cannot be indexed.
                    topi = topi.reshape(-1)
                    all_end = True
                    for i in range(batch_size):
                        token = topi[i].item()
                        if token != eos_idx:
                            decoder_outputs[i].append(word_encoder.decode(np.array([token])))
                            all_end = False
                    if all_end:
                        break
                    decoder_input = topi.detach()
            decoder_outputs = [" ".join(tokens) for tokens in decoder_outputs]

            # Swap the generated reason into the chosen wrong option.
            # NOTE(review): this mutates ``options`` in place; confirm the loader
            # does not return references into the underlying dataset.
            for i in range(batch_size):
                options[choice_label[i]][i][1] = decoder_outputs[i]
            option_input_lengths = [[len(x[0]) + len(x[1]) for x in option] for option in options]
            option_lens = [max(len(option[0]) + len(option[1]) + 1 for option in options[i])
                           for i in range(3)]
            options = [[word_encoder.encode("<sep>".join([x[0], x[1]])) for x in option]
                       for option in options]
            options = [sequence.pad_sequences(option, maxlen=option_len, padding='post')
                       for option, option_len in zip(options, option_lens)]
            options = [torch.tensor(option).type(torch.cuda.LongTensor) for option in options]
            option_hiddens = []
            for i in range(3):
                _, option_hidden = dis_option_encoder(options[i], option_input_lengths[i])
                option_hiddens.append(option_hidden)

            dis_optimizer.zero_grad()
            out = dis(batch_size, option_hiddens)
            label = torch.tensor(label).type(torch.cuda.LongTensor)
            loss = loss_fn(out, label)
            loss.backward()
            dis_optimizer.step()
            total_loss += loss.item()
            total_acc += (out.argmax(1) == label).sum().item()
            n_samples += batch_size
            sys.stdout.flush()
            count += 1
        # total_loss sums per-batch *mean* losses, so average it over batches.
        # Accuracy is averaged over the samples actually seen, because the last
        # batch may be smaller than BATCH_SIZE.
        print('\n average_loss = %.4f, train_acc = %.4f' % (
            total_loss / max(count, 1), total_acc / max(n_samples, 1)))
def train_generator_PG(dis, dis_option_encoder, state_encoder, reason_decoder,
                       gen_optimizer, epochs, beta=15, max_length=25):
    """Train the generator with an MLE loss plus a discriminator-based label loss.

    For every batch of the module-level ``train_loader_adv``:

    * the generator greedily decodes a reason.  ``mle_loss`` is the NLL of the
      reference ``reason`` token at each step, for up to ``target_length`` steps;
    * the generated reason replaces the reason of a randomly chosen wrong
      option, and the discriminator scores all three options;
    * ``label_loss`` is ``-beta * n_i * sum(out[i][wrong labels])`` summed over
      samples.  Here ``n_i`` is the shorter of the generated length and the
      encoded length of the correct option.

    NOTE(review): the generated reasons reach the discriminator only as
    re-encoded strings, so ``label_loss`` has no gradient path to the
    generator.  The generator is in effect trained by ``mle_loss`` alone.
    The discriminator is therefore evaluated under ``torch.no_grad()``; this
    also stops the old behaviour of leaking gradients into the discriminator's
    parameters.  A real policy gradient would need the log-probabilities of
    the generated tokens.

    NOTE(review): the ``gen_optimizer`` argument is ignored.  A fresh Ranger
    optimizer is built on every call.

    Args:
        dis, dis_option_encoder: discriminator head and option encoder.
        state_encoder, reason_decoder: generator encoder/decoder being trained.
        gen_optimizer: unused (see note above).
        epochs: number of passes over ``train_loader_adv``.
        beta: weight of the label loss.
        max_length: maximum number of greedy decoding steps.
    """
    gen_optimizer = Ranger(itertools.chain(state_encoder.parameters(),
                                           reason_decoder.parameters()),
                           lr=1e-3, weight_decay=1e-4)
    criterion = nn.NLLLoss()
    sos_idx = word_encoder.encode("<sos>")
    eos_idx = word_encoder.encode("<eos>")
    for epoch in tqdm(range(epochs)):
        sys.stdout.flush()
        total_mle_loss = 0.0
        total_label_loss = 0.0
        count = 0
        n_samples = 0
        for train_data in train_loader_adv:
            gen_optimizer.zero_grad()
            mle_loss = 0
            label_loss = 0
            _, statement, reason, label, options = train_data
            reason = torch.tensor(reason).type(torch.cuda.LongTensor)
            batch_size = statement.size(0)

            # For every sample, choose one of the two wrong options at random.
            falselabel = [[0, 1, 2] for _ in label]
            for i in range(len(label)):
                falselabel[i].remove(label[i])
            choice_label = [random.choice(ls) for ls in falselabel]

            input_lengths = [len(x) for x in statement]
            encoder_outputs, hidden = state_encoder(statement, input_lengths)
            decoder_input = torch.tensor(
                [[sos_idx] for _ in range(batch_size)]).type(torch.cuda.LongTensor)
            encoder_outputs = encoder_outputs.permute(1, 0, 2)  # -> (T*B*H)
            target_length = reason.size()[1]
            decoder_outputs = [[] for _ in range(batch_size)]
            reason_permute = reason.permute(1, 0)  # step-major: reason_permute[di] is step di
            for di in range(max_length):
                output, hidden = reason_decoder(decoder_input, hidden, encoder_outputs)
                if di < target_length:
                    mle_loss += criterion(output, reason_permute[di])
                _, topi = output.topk(1)
                # reshape(-1) instead of squeeze(): a batch of one must stay 1-D.
                topi = topi.reshape(-1)
                all_end = True
                for i in range(batch_size):
                    token = topi[i].item()
                    if token != eos_idx:
                        decoder_outputs[i].append(word_encoder.decode(np.array([token])))
                        all_end = False
                if all_end:
                    break
                decoder_input = topi.detach()
            generated_lengths = [len(tokens) for tokens in decoder_outputs]
            decoder_outputs = [" ".join(tokens) for tokens in decoder_outputs]

            correct_option = ["<sep>".join(options[label[i]][i]) for i in range(batch_size)]
            correct_option = [word_encoder.encode(option) for option in correct_option]
            for i in range(batch_size):
                options[choice_label[i]][i][1] = decoder_outputs[i]
            # Same length definition as train_discriminator_adv.  The old
            # ``len(x)`` measured the (statement, reason) pair itself, i.e. 2.
            option_input_lengths = [[len(x[0]) + len(x[1]) for x in option] for option in options]
            option_lens = [max(len(option[0]) + len(option[1]) + 1 for option in options[i])
                           for i in range(3)]
            options = [[word_encoder.encode("<sep>".join([x[0], x[1]])) for x in option]
                       for option in options]
            options = [sequence.pad_sequences(option, maxlen=option_len, padding='post')
                       for option, option_len in zip(options, option_lens)]
            options = [torch.tensor(option).type(torch.cuda.LongTensor) for option in options]

            with torch.no_grad():
                option_hiddens = []
                for i in range(3):
                    _, option_hidden = dis_option_encoder(options[i], option_input_lengths[i])
                    option_hiddens.append(option_hidden)
                out = dis(batch_size, option_hiddens)
                for i in range(batch_size):
                    # The old code added the same term once per position in a
                    # loop; multiplying by the repeat count is equivalent.
                    repeats = min(generated_lengths[i], correct_option[i].size()[0])
                    if repeats <= 0:
                        continue
                    false_one_hot = torch.zeros(1, 3)
                    false_one_hot[0][falselabel[i]] = 1
                    false_one_hot = false_one_hot.type(torch.cuda.FloatTensor)
                    label_loss -= beta * repeats * out[i].mul(false_one_hot).sum()

            # Accumulate plain floats.  Summing tensors kept every batch's
            # autograd graph alive for the whole epoch.
            total_label_loss += float(label_loss)
            total_mle_loss += float(mle_loss)
            loss = label_loss + mle_loss
            if torch.is_tensor(loss) and loss.requires_grad:
                loss.backward()
                gen_optimizer.step()
            n_samples += batch_size
            count += 1
        # mle_loss is a per-batch sum over time steps of batch-mean NLLs, so it
        # is averaged over batches.  label_loss is summed over samples, so it
        # is averaged over the samples seen.
        total_mle_loss = total_mle_loss / max(count, 1)
        total_label_loss = total_label_loss / max(n_samples, 1)
        print('\n average_train_NLL = ', total_mle_loss, ' the label loss = ', total_label_loss)
def train_alphaBert(DS_model, dloader, lr=1e-4, epoch=10, log_interval=20, lkahead=False):
    """Fine-tune ``DS_model`` (wrapped in DataParallel) with Ranger and cross-entropy.

    Each batch in ``dloader`` is a dict with ``'src_token'``, ``'trg'``,
    ``'mask_padding'`` and ``'origin_seq_length'``.  Targets equal to -1 are
    ignored by the loss.  A checkpoint is saved every 400 iterations and at the
    end of every epoch.  The per-sample mean loss of each epoch is appended to
    ``./iou_pic/total_loss_finetune.csv``.

    Relies on the module-level names ``device``, ``parallel``,
    ``checkpoint_file``, ``save_checkpoint`` and ``tokenize_alphabets``.

    Args:
        DS_model: model called as ``DS_model(input_ids=..., attention_mask=...,
            out='finehead')``, returning a 1-tuple of logits.
        dloader: training DataLoader.
        lr: Ranger learning rate.
        epoch: number of epochs.
        log_interval: print progress every this many iterations.
        lkahead: currently unused (the Lookahead variant is disabled).
    """
    global checkpoint_file
    DS_model.to(device)
    model_optimizer = Ranger(DS_model.parameters(), lr=lr)
    DS_model = torch.nn.DataParallel(DS_model)
    DS_model.train()
    criterion = nn.CrossEntropyLoss(ignore_index=-1).to(device)

    iteration = 0
    total_loss = []
    for ep in range(epoch):
        DS_model.train()
        t0 = time.time()
        epoch_loss = 0
        epoch_cases = 0
        for batch_idx, sample in enumerate(dloader):
            model_optimizer.zero_grad()
            src = sample['src_token']
            trg = sample['trg']
            att_mask = sample['mask_padding']
            origin_len = sample['origin_seq_length']

            bs = len(src)

            src = src.float().to(device)
            trg = trg.long().to(device)
            att_mask = att_mask.float().to(device)
            origin_len = origin_len.to(device)

            pred_prop, = DS_model(input_ids=src, attention_mask=att_mask, out='finehead')

            # Flatten the targets to match the flattened prediction rows.
            trg_view = trg.view(-1).contiguous()
            loss = criterion(pred_prop, trg_view)
            loss.backward()
            model_optimizer.step()

            epoch_loss += loss.item() * bs
            epoch_cases += bs

            if iteration % log_interval == 0:
                # Use the local batch size; ``batch_size`` is not defined in this scope.
                print('Ep:{} [{} ({:.0f}%)/ ep_time:{:.0f}min] L:{:.4f}'.format(
                    ep,
                    batch_idx * bs,
                    100. * batch_idx / len(dloader),
                    (time.time() - t0) * len(dloader) / (60 * (batch_idx + 1)),
                    loss.item()))

            if iteration % 400 == 0:
                save_checkpoint(checkpoint_file, 'd2s_total.pth', DS_model,
                                model_optimizer, parallel=parallel)
                print(tokenize_alphabets.convert_idx2str(src[0][:origin_len[0]]))
            iteration += 1

        # Checkpoint after every epoch.
        save_checkpoint(checkpoint_file, 'd2s_total.pth', DS_model, model_optimizer,
                        parallel=parallel)

        print('======= epoch:%i ========' % ep)
        print('++ Ep Time: {:.1f} Secs ++'.format(time.time() - t0))
        # Guard against an empty loader, which used to raise ZeroDivisionError.
        total_loss.append(float(epoch_loss / epoch_cases) if epoch_cases else float('nan'))
        pd_total_loss = pd.DataFrame(total_loss)
        pd_total_loss.to_csv('./iou_pic/total_loss_finetune.csv', sep=',')
        print(total_loss)
def train(args):
    """Train the DATE model with an FGSM adversarial term and early stopping.

    The loss is the BCE classification loss, plus ``alpha`` times the MSE
    revenue loss, plus ``beta`` times a BCE loss on FGSM-perturbed hidden
    vectors (weighted by ``batch_cls``).  After every epoch the model is scored
    on the validation and test loaders.  When the mean top-k F1 on validation
    improves, the whole (DataParallel-wrapped) model is saved to ``model_path``.

    Relies on these module-level names: ``leaf_num``, ``importer_size``,
    ``item_size``, ``train_loader``, ``valid_loader``, ``test_loader``,
    ``xgb_validy``, ``xgb_testy``, ``revenue_valid``, ``revenue_test``,
    ``model_path``, ``DATE``, ``Ranger``, ``fgsm_attack``, ``torch_threshold``
    and ``metrics``.

    Args:
        args: parsed CLI arguments (``epoch``, ``dim``, ``lr``, ``l2``,
            ``head_num``, ``device``, ``act``, ``fusion``, ``beta``, ``alpha``,
            ``use_self``, ``agg``).
    """
    # get configs
    epochs = args.epoch
    dim = args.dim
    lr = args.lr
    weight_decay = args.l2
    head_num = args.head_num
    device = args.device
    act = args.act
    fusion = args.fusion
    beta = args.beta
    alpha = args.alpha
    use_self = args.use_self
    agg = args.agg

    model = DATE(leaf_num, importer_size, item_size,
                 dim, head_num,
                 fusion_type=fusion, act=act, device=device,
                 use_self=use_self, agg_type=agg,
                 ).to(device)
    model = nn.DataParallel(model, device_ids=[0, 1])

    # Xavier-initialise every weight matrix (parameters with more than one dim).
    for p in model.parameters():
        if p.dim() > 1:
            nn.init.xavier_uniform_(p)

    # optimizer & loss
    optimizer = Ranger(model.parameters(), weight_decay=weight_decay, lr=lr)
    cls_loss_func = nn.BCELoss()
    reg_loss_func = nn.MSELoss()

    # save best model
    global_best_score = 0

    # early stop settings
    stop_rounds = 3
    no_improvement = 0
    current_score = None

    for epoch in range(epochs):
        for step, (batch_feature, batch_user, batch_item, batch_cls, batch_reg) in enumerate(train_loader):
            model.train()  # evaluation (and possibly fgsm_attack) may switch modes
            batch_feature, batch_user, batch_item, batch_cls, batch_reg = \
                batch_feature.to(device), batch_user.to(device), batch_item.to(device), \
                batch_cls.to(device), batch_reg.to(device)
            batch_cls, batch_reg = batch_cls.view(-1, 1), batch_reg.view(-1, 1)

            # model output
            classification_output, regression_output, hidden_vector = model(batch_feature, batch_user, batch_item)

            # FGSM attack on the hidden representation (0.01 presumably the step size)
            adv_vector = fgsm_attack(model, cls_loss_func, hidden_vector, batch_cls, 0.01)
            adv_output = model.module.pred_from_hidden(adv_vector)

            # calculate loss; the adversarial BCE is weighted by the class label
            adv_loss_func = nn.BCELoss(weight=batch_cls)
            adv_loss = beta * adv_loss_func(adv_output, batch_cls)
            cls_loss = cls_loss_func(classification_output, batch_cls)
            revenue_loss = alpha * reg_loss_func(regression_output, batch_reg)
            loss = cls_loss + revenue_loss + adv_loss

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            if (step + 1) % 1000 == 0:
                print("CLS loss:%.4f, REG loss:%.4f, ADV loss:%.4f, Loss:%.4f"
                      % (cls_loss.item(), revenue_loss.item(), adv_loss.item(), loss.item()))

        # evaluate
        model.eval()
        print("Validate at epoch %s" % (epoch + 1))
        y_prob, _ = model.module.eval_on_batch(valid_loader)
        best_threshold, _, _ = torch_threshold(y_prob, xgb_validy)
        overall_f1, auc, _, _, f1s, _ = metrics(y_prob, xgb_validy, revenue_valid)
        select_best = np.mean(f1s)
        print("Over-all F1:%.4f, AUC:%.4f, F1-top:%.4f" % (overall_f1, auc, select_best))

        print("Evaluate at epoch %s" % (epoch + 1))
        y_prob, _ = model.module.eval_on_batch(test_loader)
        overall_f1, auc, _, _, f1s, _ = metrics(y_prob, xgb_testy, revenue_test, best_thresh=best_threshold)
        print("Over-all F1:%.4f, AUC:%.4f, F1-top:%.4f" % (overall_f1, auc, np.mean(f1s)))

        # save best model
        if select_best > global_best_score:
            global_best_score = select_best
            torch.save(model, model_path)

        # early stopping
        # NOTE(review): this counts consecutive epochs whose score is lower than
        # the *previous* epoch's score, not the best score so far.  After any
        # increase the reference is reset to None, so the next epoch only
        # re-seeds it and is not compared.  Kept as-is; confirm it is intended.
        if current_score is None:
            current_score = select_best
            continue
        if select_best < current_score:
            current_score = select_best
            no_improvement += 1
            if no_improvement >= stop_rounds:
                print("Early stopping...")
                break
        if select_best > current_score:
            no_improvement = 0
            current_score = None
class Trainer(object):
    """Trainer for the Recurrent Attention Model with a saliency-map input.

    Handles model and optimizer construction, the training and validation
    loops, testing, and checkpointing.  All hyperparameters come from the
    ``config`` object.  Every input batch ``x_raw`` carries the image in
    channel 0 and its saliency map in channel 1; they are split with
    ``x_raw[:, 0]`` / ``x_raw[:, 1]`` throughout.
    """

    def __init__(self, config, data_loader):
        """Construct a new Trainer instance.

        Args
        ----
        - config: object containing command line arguments.
        - data_loader: ``(train_loader, valid_loader)`` when ``config.is_train``
          is true, otherwise the test loader.

        Raises
        ------
        - ValueError: unsupported ``config.dataset_name`` or ``config.optimizer``.
        """
        self.config = config

        # glimpse network params
        self.patch_size = config.patch_size

        # core network params
        self.num_glimpses = config.num_glimpses
        self.hidden_size = config.hidden_size

        # reinforce params
        self.std = config.std
        self.M = config.M

        # data params
        if config.is_train:
            self.train_loader = data_loader[0]
            self.valid_loader = data_loader[1]
            # Python 3 iterators (and modern DataLoader iterators) have no
            # ``.next()`` method; use the builtin next().
            image_tmp, _ = next(iter(self.train_loader))
            self.image_size = (image_tmp.shape[2], image_tmp.shape[3])
            if 'MNIST' in config.dataset_name or config.dataset_name == 'CIFAR':
                self.num_train = len(self.train_loader.sampler.indices)
                self.num_valid = len(self.valid_loader.sampler.indices)
            elif config.dataset_name == 'ImageNet':
                # The ImageNet loaders do not use an index sampler, so the sizes
                # are hard-coded (presumably len(train_dataset) in data_loader.py).
                # len(self.train_loader) would count batches, not samples.
                self.num_train = 100000
                self.num_valid = 10000
        else:
            self.test_loader = data_loader
            self.num_test = len(self.test_loader.dataset)
            image_tmp, _ = next(iter(self.test_loader))
            self.image_size = (image_tmp.shape[2], image_tmp.shape[3])

        # number of channels and classes for each supported dataset
        if 'MNIST' in config.dataset_name:
            self.num_channels = 1
            self.num_classes = 10
        elif config.dataset_name == 'ImageNet':
            self.num_channels = 3
            self.num_classes = 1000
        elif config.dataset_name == 'CIFAR':
            self.num_channels = 3
            self.num_classes = 10
        else:
            raise ValueError('Unsupported dataset_name: {}'.format(config.dataset_name))

        # training params
        self.epochs = config.epochs
        self.start_epoch = 0
        self.momentum = config.momentum
        self.lr = config.init_lr
        self.loss_fun_baseline = config.loss_fun_baseline
        self.loss_fun_action = config.loss_fun_action
        self.weight_decay = config.weight_decay

        # misc params
        self.use_gpu = config.use_gpu
        self.best = config.best
        self.ckpt_dir = config.ckpt_dir
        self.logs_dir = config.logs_dir
        self.best_valid_acc = 0.
        self.best_train_acc = 0.
        self.counter = 0  # epochs since the last validation improvement
        self.lr_patience = config.lr_patience
        self.train_patience = config.train_patience
        self.use_tensorboard = config.use_tensorboard
        self.resume = config.resume
        self.print_freq = config.print_freq
        self.plot_freq = config.plot_freq

        name_template = ('ram_gpu_{0}_{1}_{2}x{3}_{4}_{5:1.2f}_{6}' if config.use_gpu
                         else 'ram_{0}_{1}_{2}x{3}_{4}_{5:1.2f}_{6}')
        self.model_name = name_template.format(
            config.PBSarray_ID, config.num_glimpses, config.patch_size,
            config.patch_size, config.hidden_size, config.std, config.dropout)

        self.plot_dir = './plots/' + self.model_name + '/'
        os.makedirs(self.plot_dir, exist_ok=True)

        # configure tensorboard logging.  The writer is stored on the instance
        # because train_one_epoch/validate use it; it used to be a local of
        # __init__ and was passed an invalid ``logs_dir`` keyword.
        self.writer = None
        if self.use_tensorboard:
            tensorboard_dir = self.logs_dir + self.model_name
            print('[*] Saving tensorboard logs to {}'.format(tensorboard_dir))
            if not os.path.exists(tensorboard_dir):
                os.makedirs(tensorboard_dir)
            # NOTE(review): kept from the original; confirm ``configure`` is still needed.
            configure(tensorboard_dir)
            self.writer = SummaryWriter(tensorboard_dir)

        # build DRAMBUTD model
        self.model = RecurrentAttention(self.patch_size, self.num_channels, self.image_size,
                                        self.std, self.hidden_size, self.num_classes, config)
        if self.use_gpu:
            self.model.cuda()

        print('[*] Number of model parameters: {:,}'.format(
            sum([p.data.nelement() for p in self.model.parameters()])))

        # initialize optimizer
        if config.optimizer == 'SGD':
            self.optimizer = optim.SGD(self.model.parameters(), lr=self.lr,
                                       momentum=self.momentum, weight_decay=self.weight_decay)
        elif config.optimizer == 'ReduceLROnPlateau':
            # ReduceLROnPlateau is an LR scheduler, not an optimizer.  This branch
            # used to reference self.optimizer before it existed and passed an
            # unsupported ``weight_decay`` argument, so it always crashed.
            raise ValueError('ReduceLROnPlateau is a scheduler, not an optimizer; '
                             'choose SGD, Adadelta, Adam, AdaBound or Ranger')
        elif config.optimizer == 'Adadelta':
            self.optimizer = optim.Adadelta(self.model.parameters(),
                                            weight_decay=self.weight_decay)
        elif config.optimizer == 'Adam':
            self.optimizer = optim.Adam(self.model.parameters(), lr=3e-4,
                                        weight_decay=self.weight_decay)
        elif config.optimizer == 'AdaBound':
            self.optimizer = adabound.AdaBound(self.model.parameters(), lr=3e-4, final_lr=0.1,
                                               weight_decay=self.weight_decay)
        elif config.optimizer == 'Ranger':
            self.optimizer = Ranger(self.model.parameters(), weight_decay=self.weight_decay)
        else:
            raise ValueError('Unsupported optimizer: {}'.format(config.optimizer))

    def reset(self, x, SM):
        """Initialize the recurrent state for a new minibatch ``x``.

        ``h_t2``, ``l_t`` and ``SM_local_smooth`` come from
        ``self.model.initialize(x, SM)``.  ``h_t1`` and both cell states are
        zero tensors of shape ``(self.batch_size, self.hidden_size)``.  The
        caller must set ``self.batch_size`` first.
        """
        dtype = (torch.cuda.FloatTensor if self.use_gpu else torch.FloatTensor)

        h_t2, l_t, SM_local_smooth = self.model.initialize(x, SM)

        # h_t1 starts at zero so the first classification cannot come straight from context
        h_t1 = torch.zeros(self.batch_size, self.hidden_size).type(dtype)
        cell_state1 = torch.zeros(self.batch_size, self.hidden_size).type(dtype)
        cell_state2 = torch.zeros(self.batch_size, self.hidden_size).type(dtype)

        return h_t1, h_t2, l_t, cell_state1, cell_state2, SM_local_smooth

    def train(self):
        """Train the model on the training set.

        A checkpoint is saved after every epoch.  When validation accuracy
        improves, a separate "best" checkpoint is also written for use on the
        test set.  Training stops early after more than ``train_patience``
        epochs without validation improvement.
        """
        # load the most recent checkpoint
        if self.resume:
            self.load_checkpoint(best=False)

        print("\n[*] Train on {} samples, validate on {} samples".format(
            self.num_train, self.num_valid))

        for epoch in range(self.start_epoch, self.epochs):

            print('\nEpoch: {}/{} - LR: {:.6f}'.format(epoch + 1, self.epochs, self.lr))

            # train for 1 epoch
            train_loss, train_acc = self.train_one_epoch(epoch)

            # evaluate on validation set
            valid_loss, valid_acc = self.validate(epoch)

            is_best_valid = valid_acc > self.best_valid_acc
            is_best_train = train_acc > self.best_train_acc
            msg1 = "train loss: {:.3f} - train acc: {:.3f} "
            msg2 = "- val loss: {:.3f} - val acc: {:.3f}"

            if is_best_train:
                msg1 += " [*]"

            if is_best_valid:
                self.counter = 0
                msg2 += " [*]"

            msg = msg1 + msg2
            print(msg.format(train_loss, train_acc, valid_loss, valid_acc))

            # check for improvement
            if not is_best_valid:
                self.counter += 1
                if self.counter > self.train_patience:
                    print("[!] No improvement in a while, stopping training.")
                    return
            self.best_valid_acc = max(valid_acc, self.best_valid_acc)
            self.best_train_acc = max(train_acc, self.best_train_acc)
            self.save_checkpoint(
                {
                    'epoch': epoch + 1,
                    'model_state': self.model.state_dict(),
                    'optim_state': self.optimizer.state_dict(),
                    'best_valid_acc': self.best_valid_acc,
                    'best_train_acc': self.best_train_acc,
                }, is_best_valid)

    def train_one_epoch(self, epoch):
        """Train the model for one full pass over the training set.

        Called by train(); should not be called manually.

        Returns:
            ``(average loss, average accuracy in percent)`` over the epoch.
        """
        batch_time = AverageMeter()
        losses = AverageMeter()
        accs = AverageMeter()

        tic = time.time()
        with tqdm(total=self.num_train) as pbar:
            for i, (x_raw, y) in enumerate(self.train_loader):
                # Move to GPU only when requested, as validate() does (the call
                # used to be unconditional and crashed on CPU-only runs).
                if self.use_gpu:
                    x_raw, y = x_raw.cuda(), y.cuda()

                # split images and their saliency maps
                x = x_raw[:, 0, ...].unsqueeze(1)
                SM = x_raw[:, 1, ...].unsqueeze(1)

                plot = (epoch % self.plot_freq == 0) and (i == 0)

                # initialize location vector and hidden state
                self.batch_size = x.shape[0]
                h_t1, h_t2, l_t, cell_state1, cell_state2, SM_local_smooth = self.reset(x, SM)

                # save images
                imgs = [x[0:9]]

                # extract the glimpses
                locs = []
                log_pi = []
                baselines = []
                for t in range(self.num_glimpses - 1):
                    # forward pass through model
                    h_t1, h_t2, l_t, b_t, p, cell_state1, cell_state2, SM_local_smooth = self.model(
                        x, l_t, h_t1, h_t2, cell_state1, cell_state2, SM, SM_local_smooth)

                    # store
                    locs.append(l_t[0:9])
                    baselines.append(b_t)
                    log_pi.append(p)

                # last iteration
                h_t1, h_t2, l_t, b_t, log_probas, p, cell_state1, cell_state2, SM_local_smooth = self.model(
                    x, l_t, h_t1, h_t2, cell_state1, cell_state2, SM, SM_local_smooth, last=True)
                log_pi.append(p)
                baselines.append(b_t)
                locs.append(l_t[0:9])

                # convert list to tensors and reshape to (batch, num_glimpses)
                baselines = torch.stack(baselines).transpose(1, 0)
                log_pi = torch.stack(log_pi).transpose(1, 0)

                # calculate reward
                predicted = torch.max(log_probas, 1)[1]
                if self.loss_fun_baseline == 'cross_entropy':
                    # cross-entropy needs a long target, but R must later be
                    # subtracted from the N x num_glimpses baselines
                    R = (predicted.detach() == y).long()
                    loss_action, loss_baseline = self.choose_loss_fun(log_probas, y, baselines, R)
                    R = R.float().unsqueeze(1).repeat(1, self.num_glimpses)
                else:
                    R = (predicted.detach() == y).float()
                    R = R.unsqueeze(1).repeat(1, self.num_glimpses)
                    loss_action, loss_baseline = self.choose_loss_fun(log_probas, y, baselines, R)

                # REINFORCE loss: summed over timesteps, averaged across the batch
                adjusted_reward = R - baselines.detach()
                loss_reinforce = torch.sum(-log_pi * adjusted_reward, dim=1)
                loss_reinforce = torch.mean(loss_reinforce, dim=0)

                # sum up into a hybrid loss
                loss = loss_action + loss_baseline + loss_reinforce

                # compute accuracy
                correct = (predicted == y).float()
                acc = 100 * (correct.sum() / len(y))

                # store
                losses.update(loss.data.item(), x.size()[0])
                accs.update(acc.data.item(), x.size()[0])

                # compute gradients and update
                self.optimizer.zero_grad()
                loss.backward()
                self.optimizer.step()

                # measure elapsed time
                toc = time.time()
                batch_time.update(toc - tic)

                pbar.set_description(
                    ("{:.1f}s - loss: {:.3f} - acc: {:.3f}".format(
                        (toc - tic), loss.data.item(), acc.data.item())))
                pbar.update(self.batch_size)

                # dump the glimpses and locs
                if plot:
                    if self.use_gpu:
                        imgs = [g.cpu().data.numpy().squeeze() for g in imgs]
                        locs = [l.cpu().data.numpy() for l in locs]
                    else:
                        imgs = [g.data.numpy().squeeze() for g in imgs]
                        locs = [l.data.numpy() for l in locs]
                    # ``with`` closes the files; bare open() leaked the handles.
                    with open(self.plot_dir + "g_{}.p".format(epoch + 1), "wb") as f:
                        pickle.dump(imgs, f)
                    with open(self.plot_dir + "l_{}.p".format(epoch + 1), "wb") as f:
                        pickle.dump(locs, f)
                    sio.savemat(self.plot_dir + "data_train_{}.mat".format(epoch + 1),
                                mdict={'location': locs, 'patch': imgs})

                # log to tensorboard (add_scalar needs numbers, not AverageMeter objects)
                if self.use_tensorboard:
                    iteration = epoch * len(self.train_loader) + i
                    self.writer.add_scalar('Loss/train', losses.avg, iteration)
                    self.writer.add_scalar('Accuracy/train', accs.avg, iteration)

            return losses.avg, accs.avg

    def validate(self, epoch):
        """Evaluate the model on the validation set.

        Each image is repeated ``self.M`` times and the log-probabilities,
        baselines and log-pi are averaged over those repeats.  Runs under
        ``torch.no_grad()`` because nothing is back-propagated here.

        Returns:
            ``(average loss, average accuracy in percent)``.
        """
        losses = AverageMeter()
        accs = AverageMeter()

        with torch.no_grad():
            for i, (x_raw, y) in enumerate(self.valid_loader):
                if self.use_gpu:
                    x_raw, y = x_raw.cuda(), y.cuda()

                # split images and their saliency maps
                x = x_raw[:, 0, ...].unsqueeze(1)
                SM = x_raw[:, 1, ...].unsqueeze(1)

                # duplicate M times
                x = x.repeat(self.M, 1, 1, 1)
                SM = SM.repeat(self.M, 1, 1, 1)

                # initialize location vector and hidden state
                self.batch_size = x.shape[0]
                h_t1, h_t2, l_t, cell_state1, cell_state2, SM_local_smooth = self.reset(x, SM)

                # extract the glimpses
                log_pi = []
                baselines = []
                for t in range(self.num_glimpses - 1):
                    h_t1, h_t2, l_t, b_t, p, cell_state1, cell_state2, SM_local_smooth = self.model(
                        x, l_t, h_t1, h_t2, cell_state1, cell_state2, SM, SM_local_smooth)
                    baselines.append(b_t)
                    log_pi.append(p)

                # last iteration
                h_t1, h_t2, l_t, b_t, log_probas, p, cell_state1, cell_state2, SM_local_smooth = self.model(
                    x, l_t, h_t1, h_t2, cell_state1, cell_state2, SM, SM_local_smooth, last=True)
                log_pi.append(p)
                baselines.append(b_t)

                # convert list to tensors and reshape
                baselines = torch.stack(baselines).transpose(1, 0)
                log_pi = torch.stack(log_pi).transpose(1, 0)

                # average over the M repeats
                log_probas = log_probas.view(self.M, -1, log_probas.shape[-1])
                log_probas = torch.mean(log_probas, dim=0)

                baselines = baselines.contiguous().view(self.M, -1, baselines.shape[-1])
                baselines = torch.mean(baselines, dim=0)

                log_pi = log_pi.contiguous().view(self.M, -1, log_pi.shape[-1])
                log_pi = torch.mean(log_pi, dim=0)

                # calculate reward
                predicted = torch.max(log_probas, 1)[1]
                if self.loss_fun_baseline == 'cross_entropy':
                    R = (predicted.detach() == y).long()
                    loss_action, loss_baseline = self.choose_loss_fun(log_probas, y, baselines, R)
                    R = R.float().unsqueeze(1).repeat(1, self.num_glimpses)
                else:
                    R = (predicted.detach() == y).float()
                    R = R.unsqueeze(1).repeat(1, self.num_glimpses)
                    loss_action, loss_baseline = self.choose_loss_fun(log_probas, y, baselines, R)

                # compute reinforce loss
                adjusted_reward = R - baselines.detach()
                loss_reinforce = torch.sum(-log_pi * adjusted_reward, dim=1)
                loss_reinforce = torch.mean(loss_reinforce, dim=0)

                # sum up into a hybrid loss
                loss = loss_action + loss_baseline + loss_reinforce

                # compute accuracy
                correct = (predicted == y).float()
                acc = 100 * (correct.sum() / len(y))

                # store
                losses.update(loss.data.item(), x.size()[0])
                accs.update(acc.data.item(), x.size()[0])

                # log to tensorboard
                if self.use_tensorboard:
                    iteration = epoch * len(self.valid_loader) + i
                    self.writer.add_scalar('Accuracy/valid', accs.avg, iteration)
                    self.writer.add_scalar('Loss/valid', losses.avg, iteration)

        return losses.avg, accs.avg

    def choose_loss_fun(self, log_probas, y, baselines, R):
        """Compute the action and baseline losses chosen in the config.

        A dictionary of function handles replaces a switch/case.  Be careful
        with argument dtypes and shapes: each loss has its own requirements.

        Returns:
            ``(loss_action(log_probas, y), loss_baseline(baselines, R))``.
        """
        loss_fun_pool = {
            'mse': F.mse_loss,
            'l1': F.l1_loss,
            'nll': F.nll_loss,
            'smooth_l1': F.smooth_l1_loss,
            'kl_div': F.kl_div,
            'cross_entropy': F.cross_entropy
        }

        return (loss_fun_pool[self.loss_fun_action](log_probas, y),
                loss_fun_pool[self.loss_fun_baseline](baselines, R))

    def test(self):
        """Test the model on the held-out test data.

        Call this only once training has finished.  Each test batch is
        assumed to carry image and saliency channels, like the training batches.
        """
        correct = 0

        # load the best checkpoint
        self.load_checkpoint(best=self.best)

        imgs, locs = [], []
        # torch.no_grad() replaces the removed ``Variable(..., volatile=True)``.
        with torch.no_grad():
            for i, (x_raw, y) in enumerate(self.test_loader):
                if self.use_gpu:
                    x_raw, y = x_raw.cuda(), y.cuda()

                # Split image and saliency map as train/validate do; ``SM`` was
                # previously undefined here.
                # NOTE(review): assumes the test loader yields 2-channel inputs.
                x = x_raw[:, 0, ...].unsqueeze(1)
                SM = x_raw[:, 1, ...].unsqueeze(1)

                # duplicate M times
                x = x.repeat(self.M, 1, 1, 1)
                SM = SM.repeat(self.M, 1, 1, 1)

                # initialize location vector and hidden state
                self.batch_size = x.shape[0]
                h_t1, h_t2, l_t, cell_state1, cell_state2, SM_local_smooth = self.reset(x, SM)

                # save images and glimpse locations
                locs = []
                imgs = [x[0:9]]

                for t in range(self.num_glimpses - 1):
                    h_t1, h_t2, l_t, b_t, p, cell_state1, cell_state2, SM_local_smooth = self.model(
                        x, l_t, h_t1, h_t2, cell_state1, cell_state2, SM, SM_local_smooth)
                    locs.append(l_t[0:9])

                # last iteration
                h_t1, h_t2, l_t, b_t, log_probas, p, cell_state1, cell_state2, SM_local_smooth = self.model(
                    x, l_t, h_t1, h_t2, cell_state1, cell_state2, SM, SM_local_smooth, last=True)

                # average over the M repeats
                log_probas = log_probas.view(self.M, -1, log_probas.shape[-1])
                log_probas = torch.mean(log_probas, dim=0)

                pred = log_probas.data.max(1, keepdim=True)[1]
                correct += pred.eq(y.data.view_as(pred)).cpu().sum().item()

        # Dump the last batch's data (the files used to be rewritten on every batch).
        if self.use_gpu:
            imgs = [g.cpu().data.numpy().squeeze() for g in imgs]
            locs = [l.cpu().data.numpy() for l in locs]
        else:
            imgs = [g.data.numpy().squeeze() for g in imgs]
            locs = [l.data.numpy() for l in locs]
        with open(self.plot_dir + "g_test.p", "wb") as f:
            pickle.dump(imgs, f)
        with open(self.plot_dir + "l_test.p", "wb") as f:
            pickle.dump(locs, f)
        sio.savemat(self.plot_dir + "test_transient.mat", mdict={'location': locs})

        perc = (100. * correct) / (self.num_test)
        error = 100 - perc
        print('[*] Test Acc: {}/{} ({:.2f}% - {:.2f}%)'.format(
            correct, self.num_test, perc, error))

    def save_checkpoint(self, state, is_best):
        """Save ``state`` to ``<ckpt_dir>/<model_name>_ckpt.pth.tar``.

        If this model has the best validation accuracy so far, a copy is also
        written with the suffix ``_model_best.pth.tar``.
        """
        filename = self.model_name + '_ckpt.pth.tar'
        ckpt_path = os.path.join(self.ckpt_dir, filename)
        torch.save(state, ckpt_path)

        if is_best:
            filename = self.model_name + '_model_best.pth.tar'
            shutil.copyfile(ckpt_path, os.path.join(self.ckpt_dir, filename))

    def load_checkpoint(self, best=False):
        """Load a saved checkpoint into the model and optimizer.

        Used in two cases:
        - resuming training from the most recent checkpoint;
        - loading the best validation model to evaluate on the test data.

        Params
        ------
        - best: if True, load the best model (use this to evaluate on the test
          data).  Otherwise load the most recent checkpoint.
        """
        print("[*] Loading model from {}".format(self.ckpt_dir))

        filename = self.model_name + '_ckpt.pth.tar'
        if best:
            filename = self.model_name + '_model_best.pth.tar'
        ckpt_path = os.path.join(self.ckpt_dir, filename)
        # Map GPU-saved tensors to CPU when running without a GPU.
        ckpt = torch.load(ckpt_path, map_location=None if self.use_gpu else 'cpu')

        # load variables from checkpoint
        self.start_epoch = ckpt['epoch']
        self.best_valid_acc = ckpt['best_valid_acc']
        # Older checkpoints may lack this key; keep the current value then.
        self.best_train_acc = ckpt.get('best_train_acc', self.best_train_acc)
        self.model.load_state_dict(ckpt['model_state'])
        self.optimizer.load_state_dict(ckpt['optim_state'])

        if best:
            print("[*] Loaded {} checkpoint @ epoch {} "
                  "with best valid acc of {:.3f}".format(
                      filename, ckpt['epoch'], ckpt['best_valid_acc']))
        else:
            print("[*] Loaded {} checkpoint @ epoch {}".format(filename, ckpt['epoch']))
############################ 数据出口######################################### ############ 请将你的output数据重新转换成 bscwh格式进行下一步################# #output = output.unsqueeze(2) ############################################################################### #loss1 = criterionmse(output, targets) #loss2 = criterionmae(output, targets) print(targets.shape) loss = criterionbmse(output, targets) #loss3 = criteriongdl(output,target) #loss = loss1+loss2 optimizer.zero_grad() loss.backward() losses.update(loss.item()) nn.utils.clip_grad_norm_(model.parameters(), args.grad_clip) optimizer.step() #t_rmse = rmse_loss(output, targets) #rmse.update(t_rmse.item()) #output_np = np.clip(output.detach().cpu().numpy(), 0, 1) #target_np = np.clip(targets.detach().cpu().numpy(), 0, 1) logging.info('[{0}][{1}][{2}]\t' 'lr: {lr:.5f}\t' 'loss: {loss.val:.6f} ({loss.avg:.6f})'.format( epoch, headid, ind, lr=optimizer.param_groups[-1]['lr'], loss=losses)) writer.add_scalars("trainloss", {"train": losses.val}, step) step += 1
def train():
    """Train YOLACT, optionally adversarially against a mask discriminator.

    All settings come from the module-level ``args`` and ``cfg`` objects.

    When ``cfg.pred_seg`` is set, each data batch runs ``cfg.dis_iter`` rounds of:
      1. one discriminator (critic) update on detached inputs, followed by WGAN
         weight clipping to ``[-cfg.clip_value, cfg.clip_value]``;
      2. one generator (YOLACT) update on the YOLACT loss plus ``criterion_gen``.

    ``cfg.amp`` enables mixed precision for the GAN branch. Without
    ``cfg.pred_seg`` the plain YOLACT training step is used.

    Ctrl+C stops training. When ``args.interrupt`` is set, the current weights
    are saved first; the process then exits.
    """
    if not os.path.exists(args.save_folder):
        os.mkdir(args.save_folder)

    dataset = COCODetection(image_path=cfg.dataset.train_images,
                            info_file=cfg.dataset.train_info,
                            transform=SSDAugmentation(MEANS))

    # val_dataset only exists when validation is enabled; every use below is guarded the same way.
    if args.validation_epoch > 0:
        setup_eval()
        val_dataset = COCODetection(image_path=cfg.dataset.valid_images,
                                    info_file=cfg.dataset.valid_info,
                                    transform=BaseTransform(MEANS))

    # Parallel wraps the underlying module, but when saving and loading we don't want that
    yolact_net = Yolact()
    net = yolact_net
    net.train()
    print('\n--- Generator created! ---')

    if cfg.pred_seg:
        # NOTE: the image / segmentation size given to the discriminator is set manually
        # to 138; it might change in the future (e.g. 550).
        dis_size = 138
        dis_net = Discriminator_Wgan(i_size=dis_size, s_size=dis_size)
        dis_net.train()
        print('--- Discriminator created! ---\n')

    if args.log:
        log = Log(cfg.name, args.log_folder, dict(args._get_kwargs()),
                  overwrite=(args.resume is None), log_gpu_stats=args.log_gpu)

    # I don't use the timer during training (I use a different timing method).
    # Apparently there's a race condition with multiple GPUs, so disable it just to be safe.
    timer.disable_all()

    # Both of these can set args.resume to None, so do them before the check
    if args.resume == 'interrupt':
        args.resume = SavePath.get_interrupt(args.save_folder)
    elif args.resume == 'latest':
        args.resume = SavePath.get_latest(args.save_folder, cfg.name)

    if args.resume is not None:
        print('Resuming training, loading {}...'.format(args.resume))
        yolact_net.load_weights(args.resume)
        if args.start_iter == -1:
            args.start_iter = SavePath.from_str(args.resume).iteration
    else:
        print('Initializing weights...')
        yolact_net.init_weights(backbone_path=args.save_folder + cfg.backbone.path)

    # The generator (YOLACT) is optimised with Ranger; the discriminator with plain SGD.
    optimizer_gen = Ranger(net.parameters(), lr=args.lr, weight_decay=args.decay)
    if cfg.pred_seg:
        optimizer_dis = optim.SGD(dis_net.parameters(), lr=cfg.dis_lr)
        # NOTE(review): schedule_dis is created but never stepped in this function.
        schedule_dis = ReduceLROnPlateau(optimizer_dis, mode='min', patience=6, min_lr=1E-6)

    criterion = MultiBoxLoss(num_classes=cfg.num_classes,
                             pos_threshold=cfg.positive_iou_threshold,
                             neg_threshold=cfg.negative_iou_threshold,
                             negpos_ratio=cfg.ohem_negpos_ratio,
                             pred_seg=cfg.pred_seg)
    # WGAN-style critic / generator objectives (these replaced an earlier nn.BCELoss).
    criterion_dis = DiscriminatorLoss_Maskrcnn()
    criterion_gen = GeneratorLoss_Maskrcnn()

    if args.batch_alloc is not None:
        # e.g. args.batch_alloc: 24,24
        args.batch_alloc = [int(x) for x in args.batch_alloc.split(',')]
        if sum(args.batch_alloc) != args.batch_size:
            print('Error: Batch allocation (%s) does not sum to batch size (%s).' % (args.batch_alloc, args.batch_size))
            exit(-1)

    net = CustomDataParallel(NetLoss(net, criterion, pred_seg=cfg.pred_seg))
    if args.cuda:
        net = net.cuda()
        if cfg.pred_seg:
            dis_net = nn.DataParallel(dis_net)
            dis_net = dis_net.cuda()

    # Initialize everything
    if not cfg.freeze_bn:
        yolact_net.freeze_bn()  # Freeze bn so we don't kill our means
    yolact_net(torch.zeros(1, 3, cfg.max_size, cfg.max_size).cuda())
    if not cfg.freeze_bn:
        yolact_net.freeze_bn(True)

    iteration = max(args.start_iter, 0)
    last_time = time.time()

    epoch_size = len(dataset) // args.batch_size
    num_epochs = math.ceil(cfg.max_iter / epoch_size)

    # Which learning rate adjustment step are we on? lr' = lr * gamma ^ step_index
    step_index = 0

    data_loader = data.DataLoader(dataset, args.batch_size,
                                  num_workers=args.num_workers,
                                  shuffle=True, collate_fn=detection_collate,
                                  pin_memory=True)

    def save_path(epoch, iteration):
        # Weight file path for a given epoch / iteration tag.
        return SavePath(cfg.name, epoch, iteration).get_path(root=args.save_folder)

    time_avg = MovingAverage()
    # The module-level loss_types fixes the print order of the loss components.
    loss_avgs = {k: MovingAverage(100) for k in loss_types}

    # GradScaler is a pass-through when AMP is disabled.
    scaler = torch.cuda.amp.GradScaler(enabled=cfg.amp)

    def gan_forward(datum, detach):
        """Run YOLACT on ``datum``; score predicted and ground-truth masks with the discriminator.

        Returns ``(losses, output_pred, output_grou)``. ``losses`` is the raw
        (per-GPU) YOLACT loss dict. With ``detach=True`` the discriminator inputs
        are cut from the graph, so no gradient reaches the generator.
        """
        losses, seg_list, pred_list = net(datum)
        seg_clas, mask_clas, _, seg_size = seg_mask_clas(seg_list, pred_list, datum)
        # Downsample the input images to the segmentation resolution.
        image_list = [img.to(cuda0) for img in datum[0]]
        image = interpolate(torch.stack(image_list), size=seg_size,
                            mode='bilinear', align_corners=False)
        if detach:
            image, seg_clas, mask_clas = image.detach(), seg_clas.detach(), mask_clas.detach()
        output_pred = dis_net(img=image, seg=seg_clas)
        output_grou = dis_net(img=image, seg=mask_clas)
        return losses, output_pred, output_grou

    print('Begin training!')
    print()

    # try-except so you can use ctrl+c to save early and stop training
    try:
        for epoch in range(num_epochs):
            # Resume from start_iter
            if (epoch + 1) * epoch_size < iteration:
                continue

            for datum in data_loader:
                # Stop if we've reached an epoch if we're resuming from start_iter
                if iteration == (epoch + 1) * epoch_size:
                    break

                # Stop at the configured number of iterations even if mid-epoch
                if iteration == cfg.max_iter:
                    break

                # Change a config setting if we've reached the specified iteration
                changed = False
                for change in cfg.delayed_settings:
                    if iteration >= change[0]:
                        changed = True
                        cfg.replace(change[1])
                        # Reset the loss averages because things might have changed.
                        # Iterate the values: iterating the dict itself yields the key strings.
                        for avg in loss_avgs.values():
                            avg.reset()

                # If a config setting was changed, remove it from the list so we don't keep checking
                if changed:
                    cfg.delayed_settings = [x for x in cfg.delayed_settings if x[0] > iteration]

                # Warm up by linearly interpolating the learning rate from some smaller value
                if cfg.lr_warmup_until > 0 and iteration <= cfg.lr_warmup_until:
                    set_lr(optimizer_gen, (args.lr - cfg.lr_warmup_init) * (iteration / cfg.lr_warmup_until) + cfg.lr_warmup_init)

                # Adjust the learning rate at the given iterations, but also if we resume from past that iteration
                while step_index < len(cfg.lr_steps) and iteration >= cfg.lr_steps[step_index]:
                    step_index += 1
                    set_lr(optimizer_gen, args.lr * (args.gamma ** step_index))

                if cfg.pred_seg:
                    # ====== GAN training ======
                    for _ in range(cfg.dis_iter):
                        # ----- Discriminator (critic) step -----
                        with torch.cuda.amp.autocast(enabled=cfg.amp):
                            losses, output_pred, output_grou = gan_forward(datum, detach=True)
                            # Wasserstein (earth-mover) critic loss
                            loss_dis = criterion_dis(input=output_grou, target=output_pred)
                        # The previous generator backward also wrote gradients into dis_net's
                        # parameters; clear them so they do not leak into this critic update.
                        optimizer_dis.zero_grad()
                        if cfg.amp:
                            scaler.scale(loss_dis).backward()
                            scaler.step(optimizer_dis)
                            scaler.update()
                        else:
                            loss_dis.backward()
                            optimizer_dis.step()
                        optimizer_dis.zero_grad()
                        # WGAN weight clipping of the updated critic parameters
                        for par in dis_net.parameters():
                            par.data.clamp_(-cfg.clip_value, cfg.clip_value)

                        # ----- Generator step -----
                        with torch.cuda.amp.autocast(enabled=cfg.amp):
                            # Not detached: the critic's signal must flow back into YOLACT.
                            losses, output_pred, output_grou = gan_forward(datum, detach=False)
                            loss_gen = criterion_gen(input=output_grou, target=output_pred)
                            losses = {k: v.mean() for k, v in losses.items()}  # Mean here because Dataparallel
                            loss = sum(losses.values()) + loss_gen
                        # Zero BEFORE backward. Zeroing between backward() and step(), as the non-AMP
                        # path used to, throws the gradients away so the generator never learns.
                        optimizer_gen.zero_grad()
                        if cfg.amp:
                            scaler.scale(loss).backward()
                            scaler.step(optimizer_gen)
                            scaler.update()
                        else:
                            loss.backward()  # Do this to free up vram even if loss is not finite
                            if torch.isfinite(loss).item():
                                optimizer_gen.step()
                else:
                    # ====== Normal YOLACT training ======
                    # Zero the grad to get ready to compute gradients
                    optimizer_gen.zero_grad()

                    # Forward Pass + Compute loss at the same time (see CustomDataParallel and NetLoss)
                    losses = net(datum)
                    losses = {k: v.mean() for k, v in losses.items()}  # Mean here because Dataparallel
                    loss = sum(losses.values())

                    # Backprop
                    loss.backward()  # Do this to free up vram even if loss is not finite
                    if torch.isfinite(loss).item():
                        optimizer_gen.step()

                # Add the loss to the moving average for bookkeeping
                for k in losses:
                    loss_avgs[k].add(losses[k].item())

                cur_time = time.time()
                elapsed = cur_time - last_time
                last_time = cur_time

                # Exclude graph setup from the timing information
                if iteration != args.start_iter:
                    time_avg.add(elapsed)

                if iteration % 10 == 0:
                    eta_str = str(datetime.timedelta(seconds=(cfg.max_iter - iteration) * time_avg.get_avg())).split('.')[0]

                    total = sum([loss_avgs[k].get_avg() for k in losses])
                    loss_labels = sum([[k, loss_avgs[k].get_avg()] for k in loss_types if k in losses], [])

                    # NOTE(review): progress is only printed in GAN mode; plain YOLACT training prints nothing here.
                    if cfg.pred_seg:
                        print(('[%3d] %7d ||' + (' %s: %.3f |' * len(losses)) + ' T: %.3f || ETA: %s || timer: %.3f')
                              % tuple([epoch, iteration] + loss_labels + [total, eta_str, elapsed]), flush=True)

                # Loss key: B box localization, C class confidence, M mask, P prototype,
                # D coefficient diversity, E class existence, S semantic segmentation, T total.
                if args.log:
                    precision = 5
                    loss_info = {k: round(losses[k].item(), precision) for k in losses}
                    loss_info['T'] = round(loss.item(), precision)

                    if args.log_gpu:
                        log.log_gpu_stats = (iteration % 10 == 0)  # nvidia-smi is sloooow

                    # NOTE(review): cur_lr is not defined in this function; presumably the module-level
                    # value maintained by set_lr — confirm.
                    log.log('train', loss=loss_info, epoch=epoch, iter=iteration,
                            lr=round(cur_lr, 10), elapsed=elapsed)

                    log.log_gpu_stats = args.log_gpu

                iteration += 1

                if iteration % args.save_interval == 0 and iteration != args.start_iter:
                    if args.keep_latest:
                        latest = SavePath.get_latest(args.save_folder, cfg.name)

                    print('Saving state, iter:', iteration)
                    yolact_net.save_weights(save_path(epoch, iteration))

                    if args.keep_latest and latest is not None:
                        if args.keep_latest_interval <= 0 or iteration % args.keep_latest_interval != args.save_interval:
                            print('Deleting old save...')
                            os.remove(latest)

            # This is done per epoch
            if args.validation_epoch > 0:
                if epoch % args.validation_epoch == 0 and epoch > 0:
                    cfg.gan_eval = False
                    # dis_net only exists in GAN mode.
                    if cfg.pred_seg:
                        dis_net.eval()
                    compute_validation_map(epoch, iteration, yolact_net, val_dataset, log if args.log else None)
                    if cfg.pred_seg:
                        # Restore training mode; otherwise the critic stays in eval mode for the rest of training.
                        dis_net.train()

        # Compute validation mAP after training is finished (only possible when val_dataset exists).
        if args.validation_epoch > 0:
            compute_validation_map(epoch, iteration, yolact_net, val_dataset, log if args.log else None)
    except KeyboardInterrupt:
        if args.interrupt:
            print('Stopping early. Saving network...')

            # Delete previous copy of the interrupted network so we don't spam the weights folder
            SavePath.remove_interrupt(args.save_folder)

            yolact_net.save_weights(save_path(epoch, repr(iteration) + '_interrupt'))
        exit()

    yolact_net.save_weights(save_path(epoch, iteration))
class TOMTrainer:
    """Adversarial trainer for a try-on module: generator ``gen`` plus discriminator ``dis``.

    The generator receives ``cat([feature, cloth], dim=1)``. Its output is split into:
      * a rendered person (first 3 channels, passed through tanh);
      * a composition mask (remaining channels, passed through sigmoid) that blends
        the cloth into the rendering.

    The generator is trained with L1 + VGG + mask-L1 losses plus an adversarial BCE
    term. The discriminator is trained with BCE on real vs. generated try-on images.

    Args:
        gen, dis: generator / discriminator modules (moved to the selected device).
        dataloader_train, dataloader_val: loaders yielding dicts. Entries whose key
            contains ``'name'`` are not moved to the device.
        gpu_id: CUDA device index, used when CUDA is available.
        log_freq: log and write image grids every ``log_freq`` training iterations.
        save_dir: root directory handed to ``board_add_images``.
        n_step: stored as ``self.n_step`` (not used by this class).
        optimizer: ``'adam'`` or ``'ranger'``.

    Raises:
        ValueError: for any other ``optimizer`` name.
    """

    def __init__(self, gen, dis, dataloader_train, dataloader_val, gpu_id,
                 log_freq, save_dir, n_step, optimizer='adam'):
        if torch.cuda.is_available():
            self.device = torch.device('cuda:' + str(gpu_id))
        else:
            self.device = torch.device('cpu')
        self.gen = gen.to(self.device)
        self.dis = dis.to(self.device)
        self.dataloader_train = dataloader_train
        self.dataloader_val = dataloader_val

        if optimizer == 'adam':
            self.optim_g = torch.optim.Adam(self.gen.parameters(), lr=1e-4, betas=(0.5, 0.999))
            self.optim_d = torch.optim.Adam(self.dis.parameters(), lr=1e-4, betas=(0.5, 0.999))
        elif optimizer == 'ranger':
            self.optim_g = Ranger(self.gen.parameters())
            self.optim_d = Ranger(self.dis.parameters())
        else:
            # An unknown name used to leave optim_g / optim_d undefined and fail much later.
            raise ValueError("optimizer must be 'adam' or 'ranger', got %r" % (optimizer,))

        self.criterionL1 = nn.L1Loss()
        self.criterionVGG = VGGLoss()
        self.criterionAdv = torch.nn.BCELoss()
        self.log_freq = log_freq
        self.save_dir = save_dir
        self.n_step = n_step
        self.step = 0
        print('Generator Parameters:', sum([p.nelement() for p in self.gen.parameters()]))
        print('Discriminator Parameters:', sum([p.nelement() for p in self.dis.parameters()]))

    def train(self, epoch):
        """Iterate 1 epoch over train data and return loss
        """
        return self.iteration(epoch, self.dataloader_train)

    def val(self, epoch):
        """Iterate 1 epoch over validation data and return loss
        """
        return self.iteration(epoch, self.dataloader_val, train=False)

    @staticmethod
    def _ssim(fake, real):
        """Return the batch-mean SSIM between two (batch, channel, height, width) image tensors.

        Tensors are detached, copied to host memory and transposed to channel-last,
        then compared image by image.

        NOTE(review): assumes ``measure`` is ``skimage.measure`` from a scikit-image
        version that still provides ``compare_ssim`` — confirm.
        """
        import numpy as np

        fake_np = fake.detach().cpu().numpy().transpose(0, 2, 3, 1)
        real_np = real.detach().cpu().numpy().transpose(0, 2, 3, 1)
        scores = [measure.compare_ssim(f, r, multichannel=True) for f, r in zip(fake_np, real_np)]
        return float(np.mean(scores)) if scores else float('nan')

    def iteration(self, epoch, data_loader, train=True):
        """Run one pass over ``data_loader``; update both networks when ``train`` is True.

        Returns the mean over batches of ``loss_g + loss_d`` (0.0 for an empty loader).
        """
        data_iter = tqdm(enumerate(data_loader),
                         desc='epoch: %d' % (epoch),
                         total=len(data_loader),
                         bar_format='{l_bar}{r_bar}')
        total_loss = 0.0

        for i, _data in data_iter:
            # Load data on the trainer's device. Keys containing 'name' are skipped (presumably non-tensor metadata).
            data = {key: value.to(self.device) for key, value in _data.items() if 'name' not in key}
            cloth = data['cloth']
            cloth_mask = data['cloth_mask']
            person = data['person']
            batch_size = person.shape[0]

            # Gradients are only needed for training; validation passes build no autograd graph.
            with torch.set_grad_enabled(train):
                outputs = self.gen(torch.cat([data['feature'], cloth], 1))  # (batch, channel, height, width)
                rendered_person, composition_mask = torch.split(outputs, 3, 1)
                rendered_person = torch.tanh(rendered_person)
                composition_mask = torch.sigmoid(composition_mask)
                tryon_person = cloth * composition_mask + rendered_person * (1 - composition_mask)

                # Adversarial ground truths, created directly on self.device. torch.cuda.FloatTensor
                # allocates on the *current* CUDA device, which need not be cuda:gpu_id.
                real = torch.ones(batch_size, 1, device=self.device)
                fake = torch.zeros(batch_size, 1, device=self.device)

                # Reconstruction losses, plus the loss measuring the generator's ability to fool the discriminator
                l_l1 = self.criterionL1(tryon_person, person)
                l_mask = self.criterionL1(composition_mask, cloth_mask)
                l_vgg = self.criterionVGG(tryon_person, person)
                dis_fake = self.dis(torch.cat([data['feature'], cloth, tryon_person], 1))
                l_adv = self.criterionAdv(dis_fake, real)
                loss_g = l_l1 + l_vgg + l_mask + l_adv / batch_size

                # Loss for discriminator: real pairs -> 1, generated (detached) pairs -> 0
                loss_d_real = self.criterionAdv(
                    self.dis(torch.cat([data['feature'], cloth, person], 1)), real)
                loss_d_fake = self.criterionAdv(
                    self.dis(torch.cat([data['feature'], cloth, tryon_person], 1).detach()), fake)
                loss_d = (loss_d_real + loss_d_fake) / 2

            visuals = [[data['head'], data['shape'], data['pose']],
                       [cloth, cloth_mask * 2 - 1, composition_mask * 2 - 1],
                       [rendered_person, tryon_person, person]]
            metric = self._ssim(tryon_person, person)

            if train:
                self.optim_g.zero_grad()
                loss_g.backward()
                self.optim_g.step()
                # zero_grad also discards the gradients loss_g.backward() wrote into the discriminator.
                self.optim_d.zero_grad()
                loss_d.backward()
                self.optim_d.step()
                self.step += 1

            total_loss = total_loss + loss_g.item() + loss_d.item()
            post_fix = {
                'epoch': epoch,
                'iter': i,
                'avg_loss': total_loss / (i + 1),
                'loss_recon': l_l1.item() + l_vgg.item() + l_mask.item(),
                'loss_g': l_adv.item(),
                'loss_d': loss_d.item(),
                'ssim': metric
            }

            if train and i % self.log_freq == 0:
                data_iter.write(str(post_fix))
                board_add_images(visuals, epoch, i, os.path.join(self.save_dir, 'train'))

        return total_loss / max(len(data_iter), 1)
def train(model, exp_id):
    """Fine-tune ``model`` on the data returned by ``load_tune_data()``.

    Checkpoints go to ``experiment/{exp_id}/save_models``:
      * every ``cfg['train_params']['checkpoint_steps']`` iterations;
      * at the last iteration of each epoch.

    ``validation(model, device)`` runs every ``cfg['train_params']['validation_steps']``
    iterations and at the last iteration of each epoch.

    A table of (epoch, iter, rolling loss, validation loss, minutes) is written to
    ``experiment/{exp_id}/train_result.csv``. On Ctrl+C, an interrupt checkpoint and
    ``interrupt_train_result.csv`` are written, then KeyboardInterrupt is re-raised.
    """
    from collections import deque

    save_model_dir = f'experiment/{exp_id}/save_models'
    try:
        # makedirs also creates experiment/{exp_id} when it does not exist yet.
        os.makedirs(save_model_dir, exist_ok=True)
    except OSError as e:
        print(e)

    # ==== INIT MODEL =================
    device = get_device()
    model.to(device)
    optimizer = Ranger(model.parameters(), lr=cfg["model_params"]["lr"], weight_decay=0.0001)
    train_dataloader, gt_dict = load_tune_data()
    start = time.time()

    train_result = {
        'epoch': [],
        'iter': [],
        'loss[-k:](avg)': [],
        'validation_loss': [],
        'time(minute)': [],
    }

    def append_train_result(epoch, iteration, avg_lossk, validation_loss, run_time):
        # One row of the result table; -1 marks a value not measured at this step.
        train_result['epoch'].append(epoch)
        train_result['iter'].append(iteration)
        train_result['loss[-k:](avg)'].append(avg_lossk)
        train_result['validation_loss'].append(validation_loss)
        train_result['time(minute)'].append(run_time)

    def save_train_result(csv_name):
        # Dump the result table under experiment/{exp_id} and report the elapsed time.
        results = pd.DataFrame(train_result)
        results.to_csv(f"experiment/{exp_id}/{csv_name}", index=False)
        print(f"Total training time is {(time.time()-start)/60} mins")
        print(results)

    # Rolling window of the last k losses (deque drops the oldest entry in O(1)).
    k = 1000
    lossk = deque(maxlen=k)
    checkpoint_steps = cfg['train_params']['checkpoint_steps']
    validation_steps = cfg['train_params']['validation_steps']

    torch.set_grad_enabled(True)
    num_iter = len(train_dataloader)
    print(num_iter)
    for epoch in range(cfg['train_params']['epoch']):
        model.train()
        torch.set_grad_enabled(True)
        tr_it = iter(train_dataloader)
        train_progress_bar = tqdm(range(num_iter))
        print('epoch:', epoch)
        for i in train_progress_bar:
            try:
                data = next(tr_it)
                preds, confidences = forward(data, model, device)

                # convert to world positions
                world_from_agents = data['world_from_agent'].float().to(device)
                centroids = data['centroid'].float().to(device)
                # One transform copy per predicted mode: bs * 3 * 3 * 3
                world_from_agents = world_from_agents.unsqueeze(1).repeat(1, 3, 1, 1)
                # Broadcasting the (bs, 1, 1, 2) centroid offsets replaces the former explicit
                # repeat to (bs, 3, 50, 2). Presumably the transformed preds are
                # (bs, modes, steps, 2) — TODO confirm.
                preds = transform_ts_points(preds, world_from_agents.clone()) - centroids[:, None, None, :2]

                # get ground_truth
                target_availabilities = []
                target_positions = []
                for track_id, timestamp in zip(data['track_id'], data['timestamp']):
                    # NOTE(review): plain string concatenation can collide (track 1 / ts 23 vs
                    # track 12 / ts 3); it must match however gt_dict was keyed.
                    gt = gt_dict[str(track_id.item()) + str(timestamp.item())]
                    target_positions.append(torch.tensor(gt['coord']))
                    target_availabilities.append(torch.tensor(gt['avail']))
                target_availabilities = torch.stack(target_availabilities).to(device)
                target_positions = torch.stack(target_positions).to(device)

                loss = criterion(target_positions, preds, confidences, target_availabilities)

                # Backward pass
                optimizer.zero_grad()
                loss.backward()
                optimizer.step()

                lossk.append(loss.item())
                train_progress_bar.set_description(
                    f"loss: {loss.item():.4f} loss[-k:](avg): {np.mean(lossk):.4f}")

                is_last = i == num_iter - 1
                if (i > 0 and i % checkpoint_steps == 0) or is_last:
                    # save model per checkpoint
                    torch.save(model.state_dict(), f'{save_model_dir}/epoch{epoch:02d}_iter{i:05d}.pth')
                    append_train_result(epoch, i, np.mean(lossk), -1, (time.time()-start)/60)
                if (i > 0 and i % validation_steps == 0) or is_last:
                    validation_loss = validation(model, device)
                    append_train_result(epoch, i, -1, validation_loss, (time.time()-start)/60)
                    # validation() presumably switches to eval / disables grads; restore training state.
                    model.train()
                    torch.set_grad_enabled(True)
            except KeyboardInterrupt:
                torch.save(model.state_dict(), f'{save_model_dir}/interrupt_epoch{epoch:02d}_iter{i:05d}.pth')
                # save train result
                save_train_result("interrupt_train_result.csv")
                raise
        del tr_it, train_progress_bar

    # save train result
    save_train_result("train_result.csv")
def main(conf):
    """Evaluate, and optionally adapt per track, a pretrained separation model on the MUSDB test set.

    For every test track the pretrained model at ``conf['model_path']`` is reloaded.
    Unless ``conf['method'] == 'baseline'``, it is then fine-tuned on that track's
    mixture for ``conf['epochs_fine']`` epochs, using one of:
      * ``'Lmix'``: an L1 mixture-reconstruction loss;
      * ``'LmixLact'``: reconstruction of the mixture from activation-weighted estimates,
        plus ``conf['gamma']`` times an activation loss.

    The activations come from ``compute_activation_confidence`` applied to the
    reference stems. The (adapted) model is then applied to the whole track and
    evaluated with ``evaluate_mia``. Per-track and aggregated results are written
    under ``conf['exp_dir']``.

    Raises:
        ValueError: if ``conf['method']`` is not 'baseline', 'LmixLact' or 'Lmix'.
    """
    if conf['method'] not in ['baseline', 'LmixLact', 'Lmix']:
        raise ValueError("method must be baseline, LmixLact or Lmix")

    # Set random seeds both for pytorch and numpy
    th.manual_seed(conf['seed'])
    np.random.seed(conf['seed'])

    # Create experiment folder and save conf file with the final configuration
    os.makedirs(conf['exp_dir'], exist_ok=True)
    conf_path = os.path.join(conf['exp_dir'], 'conf.yml')
    with open(conf_path, 'w') as outfile:
        yaml.safe_dump(conf, outfile)

    # Load test set. Be careful about is_wav!
    test_set = musdb.DB(root=conf['musdb_path'], subsets=["test"], is_wav=True)

    # Randomly choose the indexes of tracks to save as examples.
    ex_save_dir = os.path.join(conf['exp_dir'], 'examples/')
    if conf['n_save_ex'] == -1:
        conf['n_save_ex'] = len(test_set)
    save_idx = random.sample(range(len(test_set)), conf['n_save_ex'])

    # If stop_index==-1, evaluate the whole test set
    if conf['stop_index'] == -1:
        conf['stop_index'] = len(test_set)

    # prepare data frames
    results_applyact = museval.EvalStore()
    results_adapt = museval.EvalStore()
    silence_adapt = pd.DataFrame({
        'target': [],
        'PES': [],
        'EPS': [],
        'track': []
    })

    # Loop over test examples
    for idx in range(len(test_set)):
        torch.set_grad_enabled(False)
        track = test_set.tracks[idx]
        print(idx, str(track.name))

        # Create local directory
        local_save_dir = os.path.join(ex_save_dir, str(track.name))
        os.makedirs(local_save_dir, exist_ok=True)

        # Load mixture and normalise it with the statistics of its mono downmix
        mix = th.from_numpy(track.audio).t().float()
        ref = mix.mean(dim=0)  # mono mixture
        mix = (mix - ref.mean()) / ref.std()

        # Reload the pretrained model for every track, so each adaptation starts from it
        klass, args, kwargs, state = torch.load(conf['model_path'], 'cpu')
        model = klass(*args, **kwargs)
        model.load_state_dict(state)

        # Handle device placement
        if conf['use_gpu']:
            model.cuda()
        device = next(model.parameters()).device

        # Create references matrix
        references = th.stack([
            th.from_numpy(track.targets[name].audio) for name in source_names
        ]).numpy()

        # Get activations, computed separately for each of the two channels of the reference stems
        H = np.array([track.targets[name].audio for name in source_names])
        _, bn_ch1, _ = compute_activation_confidence(H[:, :, 0], theta=conf['th'], hilb=False)
        _, bn_ch2, _ = compute_activation_confidence(H[:, :, 1], theta=conf['th'], hilb=False)
        activations = th.from_numpy(np.stack((bn_ch1, bn_ch2), axis=2))

        # FINE TUNING
        if conf['method'] != 'baseline':
            print('ADAPTATION')
            torch.set_grad_enabled(True)

            # Freeze layers
            freeze(model.encoder)
            freeze(model.separator, n=conf['frozen_layers'])
            if conf['freeze_decoder']:
                freeze(model.decoder)

            optimizer = Ranger(filter(lambda p: p.requires_grad, model.parameters()), lr=conf['lr_fine'])
            loss_func = nn.L1Loss()

            # Initialize writer for Tensorboard
            writer = SummaryWriter(log_dir=local_save_dir)
            for epoch in range(conf['epochs_fine']):
                total_loss = 0
                epoch_loss = 0
                total_rec = 0
                epoch_rec = 0
                total_act = 0
                epoch_act = 0
                if conf['monitor_metrics']:
                    total_sdr = dict.fromkeys(source_names, 0)
                    epoch_sdr = dict.fromkeys(source_names, 0)
                    total_sir = dict.fromkeys(source_names, 0)
                    epoch_sir = dict.fromkeys(source_names, 0)
                    total_sar = dict.fromkeys(source_names, 0)
                    epoch_sar = dict.fromkeys(source_names, 0)
                    total_isr = dict.fromkeys(source_names, 0)
                    epoch_isr = dict.fromkeys(source_names, 0)

                # Data loader, optionally with data augmentation
                mix_set = DAdataloader(mix.numpy(),
                                       win=conf['win_fine'],
                                       hop=conf['hop_fine'],
                                       sample_rate=conf['sample_rate'],
                                       n_observations=conf['n_observations'],
                                       pitch_list=conf['pitch_list'],
                                       min_semitones=conf['min_semitones'],
                                       max_semitones=conf['max_semitones'],
                                       same_pitch_list_all_chunks=conf['same_pitch_list_all_chunks'])

                # Iterate over chunks
                for t, item in enumerate(mix_set):
                    sample, win, _ = item
                    mix_chunk = th.from_numpy(sample[None, :, :]).to(device)
                    est_chunk = model(cp(mix_chunk))
                    act_chunk = activations[None, :, win, :].transpose(3, 2).to(device)
                    # L1 penalty on the estimates weighted by (1 - activation)
                    loss_act = loss_func(est_chunk * (1 - act_chunk), torch.zeros_like(est_chunk))
                    if conf['method'] == 'LmixLact':
                        loss_rec = loss_func(mix_chunk, torch.sum(est_chunk * act_chunk, dim=1))
                        loss = loss_rec + conf['gamma'] * loss_act
                    elif conf['method'] == 'Lmix':
                        loss_rec = loss_func(mix_chunk, torch.sum(est_chunk, dim=1))
                        loss = loss_rec
                    loss.backward()
                    optimizer.step()
                    optimizer.zero_grad()

                    total_loss += loss.item()
                    epoch_loss = total_loss / (1 + t)
                    total_rec += loss_rec.item()
                    total_act += loss_act.item()
                    epoch_rec = total_rec / (1 + t)
                    epoch_act = total_act / (1 + t)

                    # Monitor sdr, sir, sar and isr over epochs
                    if conf['monitor_metrics']:
                        ref_chunk = references[:, win, :]
                        # Skip chunks in which any reference source is completely silent
                        skip = any(np.sum(ref_chunk[i, :, :]**2) == 0 for i in range(len(source_names)))
                        if not skip:
                            sdr, isr, sir, sar = museval.evaluate(
                                ref_chunk,
                                est_chunk.squeeze().transpose(1, 2).detach().cpu().numpy(),
                                win=np.inf)
                            sdr = np.array(sdr)
                            sir = np.array(sir)
                            sar = np.array(sar)
                            isr = np.array(isr)
                            for i, target in enumerate(source_names):
                                total_sdr[target] += sdr[i]
                                epoch_sdr[target] = total_sdr[target] / (1 + t)
                                total_sir[target] += sir[i]
                                epoch_sir[target] = total_sir[target] / (1 + t)
                                total_sar[target] += sar[i]
                                epoch_sar[target] = total_sar[target] / (1 + t)
                                total_isr[target] += isr[i]
                                epoch_isr[target] = total_isr[target] / (1 + t)

                if conf['monitor_metrics']:
                    for target in source_names:
                        writer.add_scalar("SDR/" + target, epoch_sdr[target], epoch)
                        writer.add_scalar("SIR/" + target, epoch_sir[target], epoch)
                        writer.add_scalar("SAR/" + target, epoch_sar[target], epoch)
                        writer.add_scalar("ISR/" + target, epoch_isr[target], epoch)
                writer.add_scalar("Loss/total", epoch_loss, epoch)
                writer.add_scalar("Loss/rec", epoch_rec, epoch)
                writer.add_scalar("Loss/act", epoch_act, epoch)

                summary = [epoch, t, epoch_loss, epoch_rec, epoch_act]
                if conf['monitor_metrics']:
                    # epoch_sdr only exists when metrics are monitored; reading it
                    # unconditionally used to raise NameError.
                    summary.append(epoch_sdr['other'])
                print('epoch, nr of training examples and loss: ', *summary)
            writer.flush()
            writer.close()

        # apply model
        print('Apply model')
        estimates = apply_model(model, mix.to(device), shifts=conf['shifts'], split=conf['split'])
        estimates = estimates * ref.std() + ref.mean()
        estimates = estimates.transpose(1, 2).cpu().numpy()

        # get results of this track
        print('Evaluate model')
        assert references.shape == estimates.shape
        track_store, silence_frames = evaluate_mia(ref=references,
                                                   est=estimates,
                                                   track_name=track.name,
                                                   source_names=source_names,
                                                   eval_silence=True,
                                                   conf=conf)

        # Aggregate results over the track and save the partials.
        # DataFrame.append was removed in pandas 2.0; pd.concat is the supported equivalent.
        # NOTE(review): assumes evaluate_mia returns silence_frames as a DataFrame — confirm.
        silence_adapt = pd.concat([silence_adapt, silence_frames], ignore_index=True)
        silence_adapt.to_json(os.path.join(conf['exp_dir'], 'silence.json'), orient='records')
        results_adapt.add_track(track_store)
        results_adapt.save(os.path.join(conf['exp_dir'], 'bss_eval_tracks.pkl'))
        print(results_adapt)

        # Save some examples with corresponding metrics in a folder
        if idx in save_idx:
            silence_frames.to_json(os.path.join(local_save_dir, 'silence_frames.json'), orient='records')
            with open(os.path.join(local_save_dir, 'metrics_museval.json'), 'w+') as f:
                f.write(track_store.json)
            sf.write(os.path.join(local_save_dir, "mixture.wav"),
                     mix.transpose(0, 1).cpu().numpy(), conf['sample_rate'])
            for name, estimate, reference, activation in zip(source_names, estimates, references, activations):
                print(name)
                unique, counts = np.unique(activation, return_counts=True)
                print(dict(zip(unique, counts / (len(activation) * 2) * 100)))
                assert estimate.shape == reference.shape
                sf.write(os.path.join(local_save_dir, name + "_est.wav"), estimate, conf['sample_rate'])
                sf.write(os.path.join(local_save_dir, name + "_ref.wav"), reference, conf['sample_rate'])
                sf.write(os.path.join(local_save_dir, name + "_act.wav"), activation.cpu().numpy(), conf['sample_rate'])

        # Evaluate results when applying the activations to the output
        if conf['apply_act_output']:
            track_store_applyact, _ = evaluate_mia(ref=references,
                                                   est=estimates * activations.cpu().numpy(),
                                                   track_name=track.name,
                                                   source_names=source_names,
                                                   eval_silence=False,
                                                   conf=conf)

            # aggregate results over the track and save the partials
            results_applyact.add_track(track_store_applyact)
            print('after applying activations')
            print(results_applyact)
            results_applyact.save(os.path.join(conf['exp_dir'], 'bss_eval_tracks_applyact.pkl'))

            # Save some examples with corresponding metrics in a folder
            if idx in save_idx:
                with open(os.path.join(local_save_dir, 'metrics_museval_applyact.json'), 'w+') as f:
                    f.write(track_store_applyact.json)
            del track_store_applyact

        # Delete some variables
        del references, mix, estimates, track, track_store, silence_frames, model

        # Stop if reached the limit
        if idx == conf['stop_index']:
            break

        print('------------------')

    # Print and save aggregated results
    print('Final results')
    print(results_adapt)
    method = museval.MethodStore()
    method.add_evalstore(results_adapt, conf['exp_dir'])
    method.save(os.path.join(conf['exp_dir'], 'bss_eval.pkl'))

    if conf['eval_silence']:
        print("mean over evaluation frames, mean over channels, mean over tracks")
        for target in source_names:
            # Drop the string label columns first: pandas >= 2.0 no longer silently skips
            # non-numeric columns in mean().
            frames = silence_adapt.loc[silence_adapt['target'] == target]
            print(target + ' ==>',
                  frames.drop(columns=['target', 'track'], errors='ignore').mean(axis=0, skipna=True))
        silence_adapt.to_json(os.path.join(conf['exp_dir'], 'silence.json'), orient='records')

    print('Final results apply act')
    print(results_applyact)
    method = museval.MethodStore()
    method.add_evalstore(results_applyact, conf['exp_dir'])
    method.save(os.path.join(conf['exp_dir'], 'bss_eval_applyact.pkl'))
with torch.set_grad_enabled(is_training): for batch_no, batch in enumerate(loader): steps[name] += 1 opt.zero_grad() # y = batch['site1']['targets'].to(device) y = batch['site1']['targets_one_hot'].to(device) out = model(batch['site1']['features'].to(device), batch['site2']['features'].to(device)) if is_training: global_step += 1 loss = loss_fn(out, y) loss.backward() opt.step() #if (epoch == (warmup_steps - 1)) and batch_no == (len(loader) - 1): # pass # skip #else: # sched.step(global_step) sched.step() curr_lr = opt.param_groups[0]['lr'] vis.line(X=[steps[name]], Y=[curr_lr], win='lr', name='lr', update='append') avg_loss = rolling_loss[name](loss.item(), steps[name])
def _alphabert_cloze_preview(ep, src, trg, prediction_scores, origin_len):
    """Print the first sample of a batch as input / prediction / target strings.

    Returns the ``(ep, input_str, predicted_str, target_str, target_mask)`` row
    that ``train_alphaBert_stage1`` appends to ``out_pred_train.csv``.
    """
    length = origin_len[0]
    # Work on a copy so the batch tensor itself is not overwritten.
    row = src[0].clone()
    a_ = tokenize_alphabets.convert_idx2str(row[:length])
    print(a_)
    print(' ******** ******** ******** ')
    _, show_pred = torch.max(prediction_scores[0], dim=1)
    # Positions with a real target (-1 is the ignored index).
    err_cloze_ = trg[0] > -1
    row[err_cloze_] = show_pred[err_cloze_].float()
    b_ = tokenize_alphabets.convert_idx2str(row[:length])
    print(b_)
    print(' ******** ******** ******** ')
    row[err_cloze_] = trg[0][err_cloze_].float()
    c_ = tokenize_alphabets.convert_idx2str(row[:length])
    print(c_)
    return (ep, a_, b_, c_, err_cloze_)


def train_alphaBert_stage1(TS_model,
                           dloader,
                           testloader,
                           lr=1e-4,
                           epoch=10,
                           log_interval=20,
                           cloze_fix=True,
                           use_amp=False,
                           lkahead=False,
                           parallel=True):
    """Pre-train ``TS_model`` (stage 1) with a token-level cross-entropy loss.

    Args:
        TS_model: model called as ``TS_model(input_ids=..., attention_mask=...)``
            whose output is unpacked as a 1-tuple of prediction scores.
        dloader: training loader yielding dicts with ``'src_token'``, ``'trg'``,
            ``'mask_padding'`` and ``'origin_seq_length'``.
        testloader: loader passed to ``test_alphaBert_stage1`` every 999
            iterations.
        lr: learning rate of the Ranger optimizer.
        epoch: number of epochs.
        log_interval: print a progress line every ``log_interval`` iterations.
        cloze_fix, lkahead: unused (their code paths were commented out); kept
            for call compatibility.
        use_amp: wrap model/optimizer with apex AMP, opt level ``"O1"``.
        parallel: wrap the model in ``torch.nn.DataParallel``.

    Side effects: saves checkpoints via ``save_checkpoint`` and writes CSV logs
    under ``./result/``. Returns ``None``.
    """
    global checkpoint_file
    TS_model.to(device)
    model_optimizer = Ranger(TS_model.parameters(), lr=lr)
    if use_amp:
        TS_model, model_optimizer = amp.initialize(TS_model,
                                                   model_optimizer,
                                                   opt_level="O1")
    if parallel:
        TS_model = torch.nn.DataParallel(TS_model)
    TS_model.train()

    # Target value -1 marks positions excluded from the loss.
    criterion = nn.CrossEntropyLoss(ignore_index=-1).to(device)

    iteration = 0
    total_loss = []
    out_pred_res = []
    out_pred_test = []

    for ep in range(epoch):
        t0 = time.time()
        epoch_loss = 0.0
        epoch_cases = 0
        for batch_idx, sample in enumerate(dloader):
            model_optimizer.zero_grad()
            src = sample['src_token']
            trg = sample['trg']
            att_mask = sample['mask_padding']
            origin_len = sample['origin_seq_length']
            bs, _ = src.shape  # unpacking enforces a 2-D src

            src = src.float().to(device)
            trg = trg.long().to(device)
            att_mask = att_mask.float().to(device)
            origin_len = origin_len.to(device)

            prediction_scores, = TS_model(input_ids=src, attention_mask=att_mask)
            # Flatten using the model's own class dimension rather than a
            # hard-coded 100; reshape (unlike view) also accepts
            # non-contiguous outputs.
            loss = criterion(
                prediction_scores.reshape(-1, prediction_scores.size(-1)),
                trg.reshape(-1))
            if use_amp:
                with amp.scale_loss(loss, model_optimizer) as scaled_loss:
                    scaled_loss.backward()
            else:
                loss.backward()
            model_optimizer.step()

            with torch.no_grad():
                epoch_loss += loss.item() * bs
                epoch_cases += bs
                if iteration % log_interval == 0:
                    print('Ep:{} [{} ({:.0f}%)/ ep_time:{:.0f}min] L:{:.4f}'.format(
                        ep, batch_idx * bs, 100. * batch_idx / len(dloader),
                        (time.time() - t0) * len(dloader) / (60 * (batch_idx + 1)),
                        loss.item()))
                if iteration % 400 == 0:
                    save_checkpoint(checkpoint_file,
                                    'd2s_total.pth',
                                    TS_model,
                                    model_optimizer,
                                    parallel=parallel)
                    out_pred_res.append(
                        _alphabert_cloze_preview(ep, src, trg, prediction_scores,
                                                 origin_len))
                    out_pd_res = pd.DataFrame(out_pred_res)
                    out_pd_res.to_csv('./result/out_pred_train.csv', sep=',')
                if iteration % 999 == 0:
                    print(' ===== Show the Test of Pretrain ===== ')
                    test_res = test_alphaBert_stage1(TS_model, testloader)
                    print(' ===== Show the Test of Pretrain ===== ')
                    out_pred_test.append((ep, *test_res))
                    out_pd_test = pd.DataFrame(out_pred_test)
                    out_pd_test.to_csv('./result/out_pred_test.csv', sep=',')
                    # The test helper may leave the model in eval mode; make
                    # sure the following batches train with dropout etc. on.
                    TS_model.train()
            iteration += 1

        save_checkpoint(checkpoint_file,
                        'd2s_total.pth',
                        TS_model,
                        model_optimizer,
                        parallel=parallel)
        print('======= epoch:%i ========' % ep)
        print('++ Ep Time: {:.1f} Secs ++'.format(time.time() - t0))
        # An empty loader would otherwise raise ZeroDivisionError here.
        total_loss.append(
            float(epoch_loss / epoch_cases) if epoch_cases else float('nan'))
        pd_total_loss = pd.DataFrame(total_loss)
        pd_total_loss.to_csv('./result/total_loss_pretrain.csv', sep=',')
    print(total_loss)
def _pretrain_pass(model, dataloader, optimizer, criterion, lr_schedule, device,
                   epoch, num_epochs, cos_epoch, total_steps, start_time, step,
                   total_loss):
    """Run one masked-reconstruction pass over ``dataloader`` for ``train_fold``.

    ``step`` and ``total_loss`` are running counters for the whole epoch; their
    updated values are returned so consecutive loaders can share them.
    """
    for data in dataloader:
        step += 1
        lr = get_lr(optimizer)
        src = data['data']
        bpps = data['bpp'].to(device)
        # Corrupt the input by random mutation or by masking, 50/50.
        if np.random.uniform() > 0.5:
            masked = mutate_rna_input(src)
        else:
            masked = mask_rna_input(src)
        src = src.to(device).long()
        masked = masked.to(device).long()
        output = model(masked, bpps)
        # One classification head per src channel (4, 3 and 7 classes); the
        # channel-1/2 targets are shifted by 4 and 7 (presumably so they
        # start at 0 -- confirm against the tokenizer).
        loss = (criterion(output[0].reshape(-1, 4), src[:, :, 0].reshape(-1))
                + criterion(output[1].reshape(-1, 3), src[:, :, 1].reshape(-1) - 4)
                + criterion(output[2].reshape(-1, 7), src[:, :, 2].reshape(-1) - 7))
        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), 1)
        optimizer.step()
        optimizer.zero_grad()
        # Accumulate a float: summing the tensor itself would keep autograd
        # history alive across the whole epoch.
        total_loss += loss.item()
        if epoch > cos_epoch:
            lr_schedule.step()
        print(
            "Epoch [{}/{}], Step [{}/{}] Loss: {:.3f} Lr:{:.6f} Time: {:.1f}"
            .format(epoch + 1, num_epochs, step, total_steps, total_loss / step,
                    lr, time.time() - start_time),
            end='\r',
            flush=True)
    return step, total_loss


def train_fold():
    """Pre-train a ``NucleicTransformer`` with masked/mutated-input reconstruction.

    Data: the length-130 test sequences form one loader; the length-107 test
    sequences plus every training sequence form a second loader. Per-epoch mean
    loss is logged to ``pretrain.csv``; weights are saved to
    ``pretrain_weights`` every ``opts.save_freq`` epochs. The LR follows a
    per-step cosine schedule once ``epoch > int(0.75 * opts.epochs)``.
    """
    opts = get_args()
    os.environ["CUDA_VISIBLE_DEVICES"] = opts.gpu_id
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    # Named *_df so the stdlib ``json`` module is not shadowed.
    train_df = pd.read_json(os.path.join(opts.path, 'train.json'), lines=True)
    train_ids = train_df.id.to_list()
    test_df = pd.read_json(os.path.join(opts.path, 'test.json'), lines=True)

    long_data = test_df[test_df.seq_length == 130]
    # Size the dummy labels / indices by the number of long rows, not by the
    # length of the full boolean mask over the whole test set.
    n_long = len(long_data)
    ids = np.asarray(long_data.id.to_list())
    long_dataset = RNADataset(long_data.sequence.to_list(), np.zeros(n_long), ids,
                              np.arange(n_long), opts.path)
    long_dataloader = DataLoader(long_dataset,
                                 batch_size=opts.batch_size,
                                 shuffle=True,
                                 num_workers=opts.workers)

    short_data = test_df[test_df.seq_length == 107]
    ids = short_data.id.to_list() + train_ids
    short_sequences = short_data.sequence.to_list() + train_df.sequence.to_list()
    short_dataset = RNADataset(short_sequences, np.zeros(len(short_sequences)), ids,
                               np.arange(len(short_sequences)), opts.path)
    short_dataloader = DataLoader(short_dataset,
                                  batch_size=opts.batch_size,
                                  shuffle=True,
                                  num_workers=opts.workers)

    checkpoints_folder = 'pretrain_weights'
    csv_file = 'pretrain.csv'
    columns = ['epoch', 'train_loss']
    logger = CSVLogger(columns, csv_file)

    model = NucleicTransformer(opts.ntoken,
                               opts.nclass,
                               opts.ninp,
                               opts.nhead,
                               opts.nhid,
                               opts.nlayers,
                               opts.kmer_aggregation,
                               kmers=opts.kmers,
                               stride=opts.stride,
                               dropout=opts.dropout,
                               pretrain=True).to(device)
    optimizer = Ranger(model.parameters(), weight_decay=opts.weight_decay)
    criterion = nn.CrossEntropyLoss()
    model = nn.DataParallel(model)
    pytorch_total_params = sum(p.numel() for p in model.parameters())
    print('Total number of paramters: {}'.format(pytorch_total_params))

    cos_epoch = int(opts.epochs * 0.75)
    total_steps = len(long_dataloader) + len(short_dataloader)
    lr_schedule = torch.optim.lr_scheduler.CosineAnnealingLR(
        optimizer, (opts.epochs - cos_epoch) * total_steps)

    for epoch in range(opts.epochs):
        model.train(True)
        t = time.time()
        optimizer.zero_grad()
        step, total_loss = 0, 0.0
        for loader in (short_dataloader, long_dataloader):
            step, total_loss = _pretrain_pass(model, loader, optimizer, criterion,
                                              lr_schedule, device, epoch,
                                              opts.epochs, cos_epoch, total_steps,
                                              t, step, total_loss)
        print('')
        # ``step`` already counts the processed batches.
        train_loss = total_loss / max(step, 1)
        torch.cuda.empty_cache()
        logger.log([epoch + 1, train_loss])
        if (epoch + 1) % opts.save_freq == 0:
            save_weights(model, optimizer, epoch, checkpoints_folder)
    get_best_weights_from_fold(opts.fold)
def _build_video_model(num_classes):
    """Create the backbone selected by the module-level ``modelName``.

    Returns ``(model, train_params)``; raises NotImplementedError otherwise.
    """
    if modelName == 'C3D':
        model = C3D(num_class=num_classes)
        model.my_load_pretrained_weights('saved_model/c3d.pickle')
    elif modelName == 'Res3D':
        model = generate_model(50)
        model = load_pretrained_model(model,
                                      './saved_model/r3d50_K_200ep.pth',
                                      n_finetune_classes=num_classes)
    else:
        print('We only implemented C3D and Res3D models.')
        raise NotImplementedError
    return model, model.parameters()


def _video_checkpoint_path(save_dir, epoch_idx):
    """Return the checkpoint file written after 0-based epoch ``epoch_idx``.

    ``train_model`` saves directly under ``save_dir`` while the original resume
    code read from ``save_dir/models``; the ``models`` location is preferred
    when that file exists, otherwise the ``save_dir`` one is used.
    """
    name = saveName + '_epoch-' + str(epoch_idx) + '.pth.tar'
    legacy = os.path.join(save_dir, 'models', name)
    if os.path.isfile(legacy):
        return legacy
    return os.path.join(save_dir, name)


def _strip_data_parallel_prefix(state_dict):
    """Drop the ``module.`` key prefix added by saving an ``nn.DataParallel`` model."""
    prefix = 'module.'
    if state_dict and all(key.startswith(prefix) for key in state_dict):
        return {key[len(prefix):]: value for key, value in state_dict.items()}
    return state_dict


def _run_video_phase(model, loader, criterion, optimizer=None, show_first=False):
    """Run one pass over ``loader``; back-propagate only when ``optimizer`` is given.

    Returns ``(sum of batch loss * batch size, number of correct predictions)``.
    """
    is_train = optimizer is not None
    running_loss = 0.0
    running_corrects = 0
    for inputs, labels in tqdm(loader):
        inputs = inputs.to(device)
        labels = labels.to(device)
        if is_train:
            optimizer.zero_grad()
        with torch.set_grad_enabled(is_train):
            outputs = model(inputs)
            loss = criterion(outputs, labels)
        probs = nn.Softmax(dim=1)(outputs)
        preds = torch.max(probs, 1)[1]  # index of the highest-probability class
        if is_train:
            loss.backward()
            optimizer.step()
        running_loss += loss.item() * inputs.size(0)
        running_corrects += torch.sum(preds == labels).item()
        if show_first:
            print('\ntemp/label:{}/{}'.format(preds[0], labels[0]))
    return running_loss, running_corrects


def _train_model_with_log(log_file, dataset, save_dir, num_classes, lr, num_epochs,
                          save_epoch, useTest, test_interval):
    """Body of ``train_model``; ``log_file`` is the open ``run/log.txt`` handle."""
    model, train_params = _build_video_model(num_classes)
    criterion = nn.CrossEntropyLoss()  # standard crossentropy loss for classification
    optimizer = Ranger(train_params, lr=lr, betas=(.95, 0.999), weight_decay=5e-4)
    print('use ranger')
    # Epoch-level schedule: stepped once after each training phase below.
    scheduler = CosineAnnealingLR(optimizer, T_max=32, eta_min=0, last_epoch=-1)

    use_cuda = torch.cuda.is_available()
    if use_cuda:
        # Move parameters before loading optimizer state so that state is
        # cast to the GPU as well.
        model = model.cuda()

    if resume_epoch == 0:
        print("Training {} from scratch...".format(modelName))
    else:
        ckpt_path = _video_checkpoint_path(save_dir, resume_epoch - 1)
        # Load all tensors onto the CPU first.
        checkpoint = torch.load(ckpt_path, map_location=lambda storage, loc: storage)
        print("Initializing weights from: {}...".format(ckpt_path))
        # Checkpoints are saved from the DataParallel wrapper, but the model is
        # not wrapped yet at this point.
        model.load_state_dict(_strip_data_parallel_prefix(checkpoint['state_dict']))
        optimizer.load_state_dict(checkpoint['opt_dict'])

    print('Total params: %.2fM' %
          (sum(p.numel() for p in model.parameters()) / 1000000.0))
    if use_cuda:
        torch.backends.cudnn.enabled = True
        torch.backends.cudnn.benchmark = True
        model = nn.DataParallel(model)
        criterion.cuda()

    writer = SummaryWriter(log_dir=os.path.join(save_dir))
    print('Training model on {} dataset...'.format(dataset))
    train_dataloader = DataLoader(VideoDataset(dataset=dataset, split='train', clip_len=16),
                                  batch_size=8, shuffle=True, num_workers=8)
    val_dataloader = DataLoader(VideoDataset(dataset=dataset, split='validation', clip_len=16),
                                batch_size=8, num_workers=8)
    test_dataloader = DataLoader(VideoDataset(dataset=dataset, split='test', clip_len=16),
                                 batch_size=8, num_workers=8)
    trainval_loaders = {'train': train_dataloader, 'val': val_dataloader}
    trainval_sizes = {x: len(trainval_loaders[x].dataset) for x in ['train', 'val']}
    test_size = len(test_dataloader.dataset)

    for epoch in range(resume_epoch, num_epochs):
        for phase in ['train', 'val']:
            start_time = timeit.default_timer()
            is_train = phase == 'train'
            # train() / eval() mainly affects BatchNorm and Dropout.
            model.train(is_train)
            running_loss, running_corrects = _run_video_phase(
                model, trainval_loaders[phase], criterion,
                optimizer if is_train else None, show_first=True)
            if is_train:
                scheduler.step()
            epoch_loss = running_loss / trainval_sizes[phase]
            epoch_acc = running_corrects / trainval_sizes[phase]
            writer.add_scalar('data/{}_loss_epoch'.format(phase), epoch_loss, epoch)
            writer.add_scalar('data/{}_acc_epoch'.format(phase), epoch_acc, epoch)
            print("[{}] Epoch: {}/{} Loss: {} Acc: {}".format(
                phase, epoch + 1, num_epochs, epoch_loss, epoch_acc))
            stop_time = timeit.default_timer()
            print("Execution time: " + str(stop_time - start_time) + "\n")
            log_file.write("\n[{}] Epoch: {}/{} Loss: {} Acc: {}".format(
                phase, epoch + 1, num_epochs, epoch_loss, epoch_acc))

        if epoch % save_epoch == (save_epoch - 1):
            ckpt_path = os.path.join(save_dir,
                                     saveName + '_epoch-' + str(epoch) + '.pth.tar')
            torch.save({
                'epoch': epoch + 1,
                'state_dict': model.state_dict(),
                'opt_dict': optimizer.state_dict(),
            }, ckpt_path)
            print("Save model at {}\n".format(ckpt_path))

        if useTest and epoch % test_interval == (test_interval - 1):
            model.eval()
            start_time = timeit.default_timer()
            running_loss, running_corrects = _run_video_phase(model, test_dataloader,
                                                              criterion)
            epoch_loss = running_loss / test_size
            epoch_acc = running_corrects / test_size
            writer.add_scalar('data/test_loss_epoch', epoch_loss, epoch)
            writer.add_scalar('data/test_acc_epoch', epoch_acc, epoch)
            print("[test] Epoch: {}/{} Loss: {} Acc: {}".format(
                epoch + 1, num_epochs, epoch_loss, epoch_acc))
            stop_time = timeit.default_timer()
            print("Execution time: " + str(stop_time - start_time) + "\n")
            log_file.write("\n[test] Epoch: {}/{} Loss: {} Acc: {}\n".format(
                epoch + 1, num_epochs, epoch_loss, epoch_acc))

    writer.close()


def train_model(dataset=dataset,
                save_dir=save_dir,
                num_classes=num_classes,
                lr=lr,
                num_epochs=nEpochs,
                save_epoch=snapshot,
                useTest=useTest,
                test_interval=nTestInterval):
    """Train and validate a video classifier, optionally evaluating on the test split.

    Args:
        dataset (str): dataset name passed to ``VideoDataset``.
        save_dir (str): directory for checkpoints and TensorBoard logs.
        num_classes (int): Number of classes in the data.
        lr (float): Ranger learning rate.
        num_epochs (int, optional): Number of epochs to train for.
        save_epoch (int): save a checkpoint every ``save_epoch`` epochs.
        useTest (bool): also evaluate on the test split.
        test_interval (int): evaluate on the test split every ``test_interval`` epochs.

    Progress is appended to ``run/log.txt``, which is closed even on error.
    """
    with open('run/log.txt', 'w') as log_file:
        _train_model_with_log(log_file, dataset, save_dir, num_classes, lr,
                              num_epochs, save_epoch, useTest, test_interval)
def train(model, exp_id):
    """Train ``model`` for ``cfg['train_params']['epoch']`` epochs.

    Each epoch runs ``cfg['train_params']['max_num_steps']`` iterations (the
    data loader is restarted if it runs out earlier). Weights are saved to
    ``experiment/{exp_id}/save_models`` every ``checkpoint_steps`` iterations
    and at the last iteration of each epoch. Validation runs every
    ``validation_steps`` iterations. A summary table goes to
    ``experiment/{exp_id}/train_result.csv``. On Ctrl-C, the weights are saved,
    ``interrupt_train_result.csv`` is written, and the KeyboardInterrupt is
    re-raised.
    """
    from collections import deque

    save_model_dir = f'experiment/{exp_id}/save_models'
    try:
        # makedirs also creates experiment/{exp_id}, which the CSVs below need.
        os.makedirs(save_model_dir, exist_ok=True)
    except OSError as e:
        print(e)

    # ==== INIT MODEL=================
    device = get_device()
    model.to(device)
    optimizer = Ranger(model.parameters(), lr=cfg["model_params"]["lr"])
    train_dataloader = load_train_data()

    start = time.time()
    train_result = {
        'epoch': [],
        'iter': [],
        'loss[-k:](avg)': [],
        'validation_loss': [],
        'time(minute)': [],
    }

    def record(epoch_no, iter_no, avg_loss, val_loss):
        # One summary row; -1 marks "not applicable" for this row type.
        train_result['epoch'].append(epoch_no)
        train_result['iter'].append(iter_no)
        train_result['loss[-k:](avg)'].append(avg_loss)
        train_result['validation_loss'].append(val_loss)
        train_result['time(minute)'].append((time.time() - start) / 60)

    def dump(csv_name):
        results = pd.DataFrame(train_result)
        results.to_csv(f"experiment/{exp_id}/{csv_name}", index=False)
        print(f"Total training time is {(time.time()-start)/60} mins")
        print(results)

    k = 1000
    lossk = deque(maxlen=k)  # sliding window of the last k batch losses
    num_iter = cfg["train_params"]["max_num_steps"]
    torch.set_grad_enabled(True)
    for epoch in range(cfg['train_params']['epoch']):
        model.train()
        torch.set_grad_enabled(True)
        tr_it = iter(train_dataloader)
        train_progress_bar = tqdm(range(num_iter))
        optimizer.zero_grad()
        print('epoch:', epoch)
        for i in train_progress_bar:
            try:
                try:
                    data = next(tr_it)
                except StopIteration:
                    # Loader shorter than max_num_steps: start another pass.
                    tr_it = iter(train_dataloader)
                    data = next(tr_it)
                preds, confidences = forward(data, model, device)
                target_availabilities = data["target_availabilities"].to(device)
                targets = data["target_positions"].to(device)
                loss = criterion(targets, preds, confidences, target_availabilities)

                # Backward pass
                optimizer.zero_grad()
                loss.backward()
                optimizer.step()

                lossk.append(loss.item())
                train_progress_bar.set_description(
                    f"loss: {loss.item():.4f} loss[-k:](avg): {np.mean(lossk):.4f}"
                )

                if ((i > 0 and i % cfg['train_params']['checkpoint_steps'] == 0)
                        or i == num_iter - 1):
                    # save model per checkpoint
                    torch.save(
                        model.state_dict(),
                        f'{save_model_dir}/epoch{epoch:02d}_iter{i:05d}.pth')
                    record(epoch, i, np.mean(lossk), -1)

                if i > 0 and i % cfg['train_params']['validation_steps'] == 0:
                    validation_loss = validation(model, device)
                    record(epoch, -1, -1, validation_loss)
                    # Validation may change mode / grad state; restore both.
                    model.train()
                    torch.set_grad_enabled(True)
            except KeyboardInterrupt:
                torch.save(
                    model.state_dict(),
                    f'{save_model_dir}/interrupt_epoch{epoch:02d}_iter{i:05d}.pth')
                dump("interrupt_train_result.csv")
                raise
        del tr_it, train_progress_bar

    dump("train_result.csv")
class Learner:
    """Train and validate an ``LSTMAutoEncoder`` on padded bounding-box sequences.

    The encoder and decoder have separate Ranger optimizers and StepLR
    schedulers, configured from ``cfg``. The per-epoch mean of batch RMSE
    values is kept in ``train_losses`` / ``val_losses``. Checkpoints and
    ``report.csv`` are written under ``cfg.OUTPUT_DIR/cfg.EXPERIMENT_NAME``.
    """

    def __init__(self, autoencoder: LSTMAutoEncoder, train_loader: DataLoader,
                 val_loader: DataLoader, cfg: Config):
        self.net = autoencoder
        self.train_loader = train_loader
        self.val_loader = val_loader

        # get from config object
        output_dir = os.path.join(cfg.OUTPUT_DIR, cfg.EXPERIMENT_NAME)
        os.makedirs(output_dir, exist_ok=True)
        self.output_dir = output_dir
        self.device = cfg.DEVICE
        self.epoch_n = cfg.EPOCH_N
        self.save_cycle = cfg.SAVE_CYCLE
        self.verbose_cycle = cfg.VERBOSE_CYCLE
        self.encoder_lr = cfg.ENCODER_LR
        self.decoder_lr = cfg.DECODER_LR
        self.encoder_gamma = cfg.ENCODER_GAMMA
        self.decoder_gamma = cfg.DECODER_GAMMA
        self.encoder_step_cycle = cfg.ENCODER_STEP_CYCLE
        self.decoder_step_cycle = cfg.DECODER_STEP_CYCLE

        # set optimizer and scheduler (only trainable parameters are optimized)
        self.encoder_optim = Ranger(params=filter(lambda p: p.requires_grad,
                                                  self.encoder.parameters()),
                                    lr=self.encoder_lr)
        # The decoder must use DECODER_LR, not ENCODER_LR.
        self.decoder_optim = Ranger(params=filter(lambda p: p.requires_grad,
                                                  self.decoder.parameters()),
                                    lr=self.decoder_lr)
        self.encoder_stepper = StepLR(self.encoder_optim,
                                      step_size=self.encoder_step_cycle,
                                      gamma=self.encoder_gamma)
        self.decoder_stepper = StepLR(self.decoder_optim,
                                      step_size=self.decoder_step_cycle,
                                      gamma=self.decoder_gamma)
        self.loss = nn.MSELoss()

        # for book-keeping
        self.crt_epoch = 0
        self.train_losses = []
        self.val_losses = []

    @property
    def encoder(self):
        return self.net.encoder

    @property
    def decoder(self):
        return self.net.decoder

    @property
    def signature(self):
        return f'[Epoch: {self.crt_epoch}]'

    @property
    def model_path(self):
        model_name = f'lstm_ae_{self.crt_epoch:04}.pth'
        return os.path.join(self.output_dir, model_name)

    @property
    def csv_path(self):
        return os.path.join(self.output_dir, 'report.csv')

    def train(self):
        """Run ``epoch_n`` train + validation epochs, saving every ``save_cycle``."""
        logger.info('Start training...')
        self.net.to(self.device)
        for epoch_i in range(self.epoch_n):
            self.crt_epoch = epoch_i + 1
            start_t = time.time()

            epoch_min = self._train_one_epoch()
            if self.crt_epoch % self.verbose_cycle == 0:
                train_avg_rmse = self.train_losses[-1]
                logger.info(
                    f'{self.signature}:: Train complete. Avg RMSE: {train_avg_rmse:04f}. Time: {epoch_min:03f} mins'
                )

            epoch_min = self._val_one_epoch()
            if self.crt_epoch % self.verbose_cycle == 0:
                val_avg_rmse = self.val_losses[-1]
                logger.info(
                    f'{self.signature}:: Val complete. Avg RMSE: {val_avg_rmse:04f}. Time: {epoch_min:03f} mins'
                )

            total_epoch_min = (time.time() - start_t) / 60.
            if self.crt_epoch % self.verbose_cycle == 0:
                logger.info(
                    f'{self.signature}:: Completed Train + Val. Time: {total_epoch_min} mins'
                )

            if self.crt_epoch % self.save_cycle == 0:
                self.save_model()
                self.save_report()
                logger.info(f'{self.signature}:: Model saved: {self.model_path}')

        logger.info(f'{self.signature}:: Training complete!')
        self.save_model()
        self.save_report()
        logger.info(
            f'{self.signature}:: Final Model and Report saved: {self.model_path}, {self.csv_path}'
        )

    @staticmethod
    def _mean(values):
        # NaN instead of ZeroDivisionError when a loader yields no batches.
        return sum(values) / len(values) if values else float('nan')

    def _batch_rmse(self, bboxs, seq_lens):
        """Return the reconstruction RMSE of ``bboxs`` over the non-padded steps.

        ``seq_lens`` must be sorted in decreasing order (``enforce_sorted=True``).
        """
        bboxs = bboxs.to(self.device)
        preds = self.net(bboxs)
        preds = pack_padded_sequence(preds, seq_lens, batch_first=True,
                                     enforce_sorted=True)
        targets = pack_padded_sequence(bboxs.clone(), seq_lens, batch_first=True,
                                       enforce_sorted=True)
        return torch.sqrt(self.loss(preds.data, targets.data))

    def _train_one_epoch(self):
        """Train one epoch; append the mean batch RMSE and return minutes taken."""
        self.net.train()
        start_t = time.time()
        rmse_ls = []
        for bboxs, seq_lens, _classes in tqdm(self.train_loader):
            self.encoder_optim.zero_grad()
            self.decoder_optim.zero_grad()
            rmse = self._batch_rmse(bboxs, seq_lens)
            rmse.backward()
            self.encoder_optim.step()
            self.decoder_optim.step()
            # Schedulers are stepped once per batch.
            self.encoder_stepper.step()
            self.decoder_stepper.step()
            rmse_ls.append(rmse.item())
        epoch_min = (time.time() - start_t) / 60.
        self.train_losses.append(self._mean(rmse_ls))
        return epoch_min

    def _val_one_epoch(self):
        """Evaluate one epoch without gradients; append the mean batch RMSE."""
        self.net.eval()
        start_t = time.time()
        rmse_ls = []
        with torch.no_grad():
            for bboxs, seq_lens, _classes in tqdm(self.val_loader):
                rmse_ls.append(self._batch_rmse(bboxs, seq_lens).item())
        epoch_min = (time.time() - start_t) / 60.
        self.val_losses.append(self._mean(rmse_ls))
        return epoch_min

    def save_model(self):
        torch.save(self.net.state_dict(), self.model_path)

    def load_model(self, ckpt_path):
        assert os.path.isfile(ckpt_path), f'Non-exist ckpt_path: {ckpt_path}'
        # map_location lets a checkpoint saved on another device load here.
        self.net.load_state_dict(torch.load(ckpt_path, map_location=self.device))
        logger.info(f'Model loaded: {ckpt_path}')

    def save_report(self):
        losses_dicts = {
            'train_losses': self.train_losses,
            'val_losses': self.val_losses
        }
        df = pd.DataFrame(losses_dicts)
        df.to_csv(self.csv_path, index=False)
def _mango_evaluate(model, test_loader, device):
    """Return ``(accuracy, mean_recall)`` of ``model`` on the 3-class ``test_loader``.

    A 3x3 confusion matrix (classes 0, 1, 2; rows = true labels) is accumulated
    over all batches. ``mean_recall`` is the unweighted mean of the three
    per-class recalls; a class with no samples contributes 0.
    """
    model.eval()
    cm = np.zeros((3, 3), dtype=np.int64)
    with torch.no_grad():
        for imgs, label in test_loader:
            imgs = imgs.permute(0, 3, 1, 2).to(device, dtype=torch.float)
            pred = model(imgs)
            # labels= keeps the matrix 3x3 even if a batch misses a class.
            cm += confusion_matrix(label.cpu().numpy(),
                                   pred.argmax(-1).cpu().numpy(),
                                   labels=[0, 1, 2])
    acc = np.trace(cm) / cm.sum()
    support = cm.sum(axis=1)
    recalls = [cm[c, c] / support[c] if support[c] else 0.0 for c in range(3)]
    return acc, sum(recalls) / 3


def main(args):
    """Train ``CNN`` on the MangoClassify data and evaluate every ``args.val_epoch`` epochs.

    Saves the best-recall, best-accuracy and per-evaluation models under
    ``args.save_dir``, then plots training loss, accuracy and recall.
    """
    # create directory to save trained model and other info
    os.makedirs(args.save_dir, exist_ok=True)

    # setup random seed
    np.random.seed(args.random_seed)
    torch.manual_seed(args.random_seed)
    torch.cuda.manual_seed(args.random_seed)
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

    best_acc = 0
    best_recall = 0
    train_loss = []
    total_wrecall = [0]
    total_acc = [0]

    print('===> loading data')
    train_set = MyDataset(r'E:\ACV\MangoClassify', 'C1-P1_Train', 'train.csv', 'train')
    test_set = MyDataset(r'E:\ACV\MangoClassify', 'C1-P1_Dev', 'dev.csv', 'test')

    print('===> build dataloader ...')
    train_loader = torch.utils.data.DataLoader(dataset=train_set,
                                               batch_size=args.train_batch,
                                               num_workers=args.workers,
                                               shuffle=False)
    test_loader = torch.utils.data.DataLoader(dataset=test_set,
                                              batch_size=args.test_batch,
                                              num_workers=args.workers,
                                              shuffle=False)

    print('===> prepare model ...')
    model = CNN().to(device)
    criterion = nn.CrossEntropyLoss()
    optimizer = Ranger(model.parameters(), lr=args.lr)

    print('===> start training ...')
    for epoch in range(1, args.epoch + 1):
        # Evaluation switches the model to eval mode; switch back every epoch.
        model.train()
        for idx, (imgs, label) in enumerate(train_loader):
            train_info = 'Epoch: [{0}][{1}/{2}]'.format(epoch, idx + 1,
                                                        len(train_loader))
            # Images arrive channels-last; the model expects channels-first.
            imgs = imgs.permute(0, 3, 1, 2).to(device, dtype=torch.float)
            label = label.to(device)

            pred = model(imgs)
            loss = criterion(pred, label)
            train_loss += [loss.item()]
            torch.cuda.empty_cache()
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            train_info += ' loss: {:.8f}'.format(loss.item())
            print(train_info)

        if epoch % args.val_epoch == 0:
            test_info = 'Epoch: [{}] '.format(epoch)
            acc, w_recall = _mango_evaluate(model, test_loader, device)
            total_wrecall += [w_recall]
            total_acc += [acc]
            test_info += 'Acc:{:.8f} '.format(acc)
            test_info += 'Recall:{:.8f} '.format(w_recall)
            print(test_info)

            if w_recall > best_recall:
                best_recall = w_recall
                save_model(model, os.path.join(args.save_dir, 'model_best_recall.h5'))
            if acc > best_acc:
                best_acc = acc
                save_model(model, os.path.join(args.save_dir, 'model_best_acc.h5'))
            save_model(
                model,
                os.path.join(
                    args.save_dir, 'model_{}_acc={:8f}_recall={:8f}.h5'.format(
                        epoch, acc, w_recall)))

    def plot_curve(values, xlabel, ylabel, title):
        plt.figure()
        plt.plot(range(1, len(values) + 1), values, '-')
        plt.xlabel(xlabel)
        plt.ylabel(ylabel)
        plt.title(title)
        plt.show()

    plot_curve(train_loss, "iteration", "loss", "training loss")
    plot_curve(total_acc, "Epoch", "Accuracy", "Best Accuracy:" + str(best_acc))
    plot_curve(total_wrecall, "Epoch", "Recall", "Best Recall:" + str(best_recall))
def main(args, logger):
    """Train ``Vgg`` on the ``Lung`` dataset with apex AMP mixed precision.

    Every ``args.evalFrequency`` epochs (never epoch 0), ``eval`` runs and F1 and
    accuracy are logged to TensorBoard. Every ``args.saveFrequency`` epochs, a
    ``{model, optimizer, amp}`` checkpoint is saved to
    ``args.subModelDir/out_{epoch}.pt``. After training, the model is
    evaluated once more and its bare weights are saved to ``final.pth``.
    """
    writer = SummaryWriter(args.subTensorboardDir)
    model = Vgg().to(device)
    trainSet = Lung(rootDir=args.dataDir, mode='train', size=args.inputSize)
    valSet = Lung(rootDir=args.dataDir, mode='test', size=args.inputSize)
    trainDataloader = DataLoader(trainSet,
                                 batch_size=args.batchSize,
                                 drop_last=True,
                                 shuffle=True,
                                 pin_memory=False,
                                 num_workers=args.numWorkers)
    valDataloader = DataLoader(valSet,
                               batch_size=args.valBatchSize,
                               drop_last=False,
                               shuffle=False,
                               pin_memory=False,
                               num_workers=args.numWorkers)
    criterion = nn.CrossEntropyLoss()
    optimizer = Ranger(model.parameters(), lr=args.lr)
    model, optimizer = amp.initialize(model, optimizer, opt_level=args.apexType)

    def bare_state_dict():
        # Unwrap (Distributed)DataParallel so saving works for distributed
        # training too; the wrapped module lives in ``.module``, not
        # ``.modules`` (which is a method).
        return model.module.state_dict() if hasattr(model, 'module') else model.state_dict()

    step = 0  # global iteration counter (avoids shadowing builtin ``iter``)
    runningLoss = []
    total_iters = len(trainDataloader) * args.epoch
    for epoch in range(args.epoch):
        if epoch != 0 and epoch % args.evalFrequency == 0:
            f1, acc = eval(model, valDataloader, logger)
            writer.add_scalars('f1_acc', {'f1': f1, 'acc': acc}, step)
        if epoch != 0 and epoch % args.saveFrequency == 0:
            modelName = osp.join(args.subModelDir, 'out_{}.pt'.format(epoch))
            # A single save: previously the bare state dict was written and
            # then immediately overwritten by this checkpoint at the same path.
            checkpoint = {
                'model': bare_state_dict(),
                'optimizer': optimizer.state_dict(),
                'amp': amp.state_dict()
            }
            torch.save(checkpoint, modelName)

        # NOTE(review): eval() is defined elsewhere; re-enter train mode in
        # case it left the model in eval mode.
        model.train()
        for img, lb, _ in trainDataloader:
            step += 1
            img, lb = img.to(device), lb.to(device)
            optimizer.zero_grad()
            outputs = model(img)
            loss = criterion(outputs.squeeze(), lb.long())
            with amp.scale_loss(loss, optimizer) as scaled_loss:
                scaled_loss.backward()
            optimizer.step()
            runningLoss.append(loss.item())

            if step % args.msgFrequency == 0:
                avgLoss = np.mean(runningLoss)
                runningLoss = []
                lr = optimizer.param_groups[0]['lr']
                logger.info(f'epoch: {epoch} / {args.epoch}, '
                            f'iter: {step} / {total_iters}, '
                            f'lr: {lr}, '
                            f'loss: {avgLoss:.4f}')
                writer.add_scalar('loss', avgLoss, step)

    eval(model, valDataloader, logger)
    modelName = osp.join(args.subModelDir, 'final.pth')
    torch.save(bare_state_dict(), modelName)
    writer.close()
def train(train_data, val_data, fold_idx=None):
    """Train ``Model`` to predict start/end token logits and keep the best weights.

    Each epoch runs one pass over ``train_data`` with Ranger (lr 5e-5). The
    cosine learning-rate scheduler is stepped every ``config.train_print_step``
    batches, with ``T_max`` equal to the number of such steps in one epoch.
    After each epoch the model is scored on ``val_data`` with ``evaluate``.
    Whenever the score matches or beats the best so far, the weights are
    saved under ``config.model_path``. When no improvement has been seen for
    ``config.patience_epoch`` epochs, the patience window is reset, up to
    ``model_config.adjust_lr_num`` times. The next time the window runs out,
    training stops.

    Relies on module-level ``config``, ``model_config``, ``tokenizer``,
    ``model_name``, ``loss_fn``, ``evaluate`` and ``model_score``.

    Args:
        train_data: raw training samples, wrapped in ``MyDataset``.
        val_data: raw validation samples, wrapped in ``MyDataset``.
        fold_idx: zero-based cross-validation fold index, or ``None``. When
            given, the save path gets a ``_fold<idx>`` suffix and the best
            score is stored in ``model_score[fold_idx]``.
    """
    train_dataset = MyDataset(train_data, tokenizer)
    val_dataset = MyDataset(val_data, tokenizer)
    train_loader = DataLoader(train_dataset, batch_size=config.batch_size)
    val_loader = DataLoader(val_dataset, batch_size=config.batch_size)
    model = Model().to(config.device)
    optimizer = Ranger(model.parameters(), lr=5e-5)
    # The scheduler steps every ``train_print_step`` batches, so one epoch
    # covers ``period`` scheduler steps (one cosine half-cycle).
    period = len(train_loader) // config.train_print_step
    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=period, eta_min=5e-9)

    if fold_idx is None:
        print('start')
        model_save_path = os.path.join(config.model_path, '{}.bin'.format(model_name))
    else:
        print('start fold: {}'.format(fold_idx + 1))
        model_save_path = os.path.join(
            config.model_path, '{}_fold{}.bin'.format(model_name, fold_idx))

    best_val_score = 0
    last_improved_epoch = 0
    adjust_lr_num = 0
    for cur_epoch in range(config.epochs_num):
        start_time = int(time.time())
        model.train()
        print('epoch:{}, step:{}'.format(cur_epoch + 1, len(train_loader)))
        cur_step = 0
        for inputs in train_loader:
            # Move ids, masks, and targets to the device as torch.long.
            ids = inputs["ids"].to(config.device, dtype=torch.long)
            token_type_ids = inputs["token_type_ids"].to(config.device, dtype=torch.long)
            attention_mask = inputs["attention_mask"].to(config.device, dtype=torch.long)
            targets_start_idx = inputs["start_idx"].to(config.device, dtype=torch.long)
            targets_end_idx = inputs["end_idx"].to(config.device, dtype=torch.long)

            optimizer.zero_grad()
            start_logits, end_logits = model(
                input_ids=ids,
                attention_mask=attention_mask,
                token_type_ids=token_type_ids,
            )
            loss = loss_fn(start_logits, end_logits, targets_start_idx, targets_end_idx)
            loss.backward()
            optimizer.step()

            cur_step += 1
            if cur_step % config.train_print_step == 0:
                scheduler.step()
                msg = 'the current step: {0}/{1}, cost: {2}s'
                print(msg.format(cur_step, len(train_loader), int(time.time()) - start_time))

        val_loss, val_score = evaluate(model, val_loader)
        if val_score >= best_val_score:
            best_val_score = val_score
            torch.save(model.state_dict(), model_save_path)
            improved_str = '*'
            last_improved_epoch = cur_epoch
        else:
            improved_str = ''
        msg = 'the current epoch: {0}/{1}, val loss: {2:>5.2}, val acc: {3:>6.2%}, cost: {4}s {5}'
        end_time = int(time.time())
        print(msg.format(cur_epoch + 1, config.epochs_num, val_loss, val_score,
                         end_time - start_time, improved_str))

        if cur_epoch - last_improved_epoch >= config.patience_epoch:
            if adjust_lr_num >= model_config.adjust_lr_num:
                print("No optimization for a long time, auto stopping...")
                break
            # NOTE(review): despite the message, nothing here changes the
            # learning rate; only the patience window is reset.
            print("No optimization for a long time, adjust lr...")
            # Reset here; otherwise this branch would fire on every following epoch.
            last_improved_epoch = cur_epoch
            adjust_lr_num += 1

    del model
    gc.collect()
    if fold_idx is not None:
        model_score[fold_idx] = best_val_score
sr=lab.view(-1) sr=torch.sort(sr)[0] print ("sr",sr.shape) plt.plot(sr[-1000:].numpy()) plt.savefig("%s/labvalues_px%s_cr%s_BN%s.png" % (savePath, opt.imageSize, opt.imageCrop,str(opt.BN))) plt.close() #raise Exception lab = lab.to(device).float() # not long but float... opti.zero_grad() pred = netD(im) err = criterion(pred, lab) err.backward() opti.step() buf.append(err.item()) for j in range(lab.shape[0]):##single element, 1 number per patch if lab[j].sum()==0: bufN.append(pred[j].max().item()) bufME.append(pred[j].mean().item()) else: bufP.append(pred[j].max().item()) if i%20 ==0 and False: print("lab", lab.sum(3).sum(2), "pred", pred.sum(3).sum(2)) print ("buf",len(bufN),len(bufP)) if i%5==0:
def train(train_data, val_data, fold_idx=None):
    """Train ``Net(model_name)`` with focal loss and keep the best checkpoint.

    Each epoch runs one shuffled pass over ``train_data`` with Ranger
    (lr 1e-3, weight decay 5e-4), then scores ``val_data`` with
    ``evaluate``. A validation score greater than or equal to the best so far
    saves the weights under ``config.model_path``. An exact tie also
    increments a tie counter.

    The patience branch fires when there has been no improvement for
    ``config.patience_epoch`` epochs, or when three ties have accumulated. It
    resets both counters, up to ``config.adjust_lr_num`` times; after that,
    training stops. The cosine scheduler (``T_max=4``) is stepped once at the
    end of every epoch that is not stopped.

    Relies on module-level ``config``, ``device``, ``model_name``,
    ``train_transform``, ``val_transform``, ``accuracy``, ``evaluate`` and
    ``model_score``.

    Args:
        train_data: raw training samples, wrapped in ``MyDataset``.
        val_data: raw validation samples, wrapped in ``MyDataset``.
        fold_idx: zero-based fold index or ``None``. When given, the save path
            gets a ``_fold<idx>`` suffix and the best score is stored in
            ``model_score[fold_idx]``.
    """
    train_set = MyDataset(train_data, train_transform)
    train_loader = DataLoader(train_set, batch_size=config.batch_size, shuffle=True)
    val_set = MyDataset(val_data, val_transform)
    val_loader = DataLoader(val_set, batch_size=config.batch_size, shuffle=False)

    model = Net(model_name).to(device)
    criterion = FocalLoss(0.5)
    optimizer = Ranger(model.parameters(), lr=1e-3, weight_decay=0.0005)
    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=4)

    if fold_idx is None:
        print('start')
        weights_file = '{}.bin'.format(model_name)
    else:
        print('start fold: {}'.format(fold_idx + 1))
        weights_file = '{}_fold{}.bin'.format(model_name, fold_idx)
    model_save_path = os.path.join(config.model_path, weights_file)

    step_msg = 'the current step: {0}/{1}, train loss: {2:>5.2}, train acc: {3:>6.2%}'
    epoch_msg = 'the current epoch: {0}/{1}, val loss: {2:>5.2}, val acc: {3:>6.2%}, cost: {4}s {5}'
    num_batches = len(train_loader)

    best_val_score = 0
    tie_count = 0
    last_improved_epoch = 0
    adjust_lr_num = 0
    for epoch_idx in range(config.epochs_num):
        epoch_start = int(time.time())
        model.train()
        print('epoch:{}, step:{}'.format(epoch_idx + 1, num_batches))

        for step, (batch_x, batch_y) in enumerate(train_loader, start=1):
            batch_x = batch_x.to(device)
            batch_y = batch_y.to(device)
            optimizer.zero_grad()
            probs = model(batch_x)
            train_loss = criterion(probs, batch_y)
            train_loss.backward()
            optimizer.step()
            if step % config.train_print_step != 0:
                continue
            # ``accuracy`` returns an indexable result; its first entry is reported.
            train_acc = accuracy(probs, batch_y)
            print(step_msg.format(step, num_batches, train_loss.item(), train_acc[0].item()))

        val_loss, val_score = evaluate(model, val_loader, criterion)
        improved = val_score >= best_val_score
        if improved:
            if val_score == best_val_score:
                tie_count += 1
            best_val_score = val_score
            torch.save(model.state_dict(), model_save_path)
            last_improved_epoch = epoch_idx
        improved_str = '*' if improved else ''
        epoch_end = int(time.time())
        print(epoch_msg.format(epoch_idx + 1, config.epochs_num, val_loss, val_score,
                               epoch_end - epoch_start, improved_str))

        stalled = epoch_idx - last_improved_epoch >= config.patience_epoch
        if stalled or tie_count >= 3:
            if adjust_lr_num >= config.adjust_lr_num:
                print("No optimization for a long time, auto stopping...")
                break
            print("No optimization for a long time, adjust lr...")
            # Reset the window so this branch does not fire again next epoch.
            last_improved_epoch = epoch_idx
            adjust_lr_num += 1
            tie_count = 0
        scheduler.step()

    del model
    gc.collect()
    if fold_idx is not None:
        model_score[fold_idx] = best_val_score
def train(model, dataset, val_X, val_y, batch_size=512, epochs=500, epoch_show=1,
          weight_decay=1e-5, momentum=0.9):
    """Fit a regression ``model`` on ``dataset`` with MSE loss and the Ranger optimizer.

    Every ``epoch_show`` epochs, and on the last epoch, the model is evaluated
    on the full ``(val_X, val_y)`` set. The train/val losses and the
    validation MAPE are printed and written to TensorBoard. Whenever the
    validation MAPE improves, the weights are saved to
    ``models/running_best_model.pt``. The weights after the last epoch are
    saved to ``models/final_model.pt``.

    Args:
        model: network exposing ``features``, ``activation_fn_name``, ``bn``
            and ``p`` attributes (used only to name the TensorBoard run).
        dataset: training dataset yielding ``(x, y)`` pairs.
        val_X: validation inputs, passed to ``model`` in a single call.
        val_y: validation targets matching ``val_X``.
        batch_size: training mini-batch size.
        epochs: number of training epochs.
        epoch_show: validation/report interval in epochs.
        weight_decay: weight decay passed to Ranger.
        momentum: not used by Ranger; only recorded in the run name.

    Returns:
        The ``SummaryWriter`` used for logging; it is left open for the caller.
    """
    train_loader = DataLoader(dataset=dataset, batch_size=batch_size, shuffle=True, num_workers=0)
    optimizer = Ranger(model.parameters(), weight_decay=weight_decay)
    loss_fn = nn.MSELoss(reduction='mean')
    writer = SummaryWriter(
        'runs/batchSize/linear_hidden={}_neurons={}_{}_batch={:04d}_bn={}_p={:.1f}_mom={}_l2={}'
        .format(len(model.features) - 1, model.features[-1], model.activation_fn_name,
                batch_size, model.bn, model.p, momentum, weight_decay))

    best_val_mape = 100
    for epoch in range(epochs):
        model.train()
        loss_epoch = 0.0
        num_batches = 0
        for x, y in train_loader:
            out = model(x)
            # MSE is symmetric; arguments follow the (prediction, target) convention.
            loss_iter = loss_fn(out.squeeze(), y)
            optimizer.zero_grad()
            loss_iter.backward()
            optimizer.step()
            # Accumulate a Python float. Summing the loss tensor itself would
            # keep each iteration's autograd history alive for the whole epoch.
            loss_epoch += loss_iter.item()
            num_batches += 1
        # max(..., 1) keeps an empty loader from failing here.
        loss_train = loss_epoch / max(num_batches, 1)

        if epoch % epoch_show == 0 or epoch == epochs - 1:
            with torch.no_grad():
                model.eval()
                pred = model(val_X)
                loss_val = loss_fn(pred.squeeze(), val_y)
                mape = mean_absolute_percentage_error(val_y, pred.squeeze())
            print(
                '\nEpoch {:03d}: loss_train={:.6f}, loss_val={:.6f}, val_mape={:.4f}, '
                'best_val_mape={:.4f}'.format(epoch, loss_train, loss_val, mape, best_val_mape),
                end=' ')
            if mape < best_val_mape:
                model_name = 'running_best_model.pt'
                print(
                    'Val_mape improved from {:.4f} to {:.4f}, saving model to {}'
                    .format(best_val_mape, mape, model_name),
                    end=' ')
                best_val_mape = mape
                torch.save(model.state_dict(), 'models/' + model_name)
            writer.add_scalar("loss/train", loss_train, epoch)
            writer.add_scalar("loss/val", loss_val, epoch)
            writer.add_scalar("mape/val", mape, epoch)

    torch.save(model.state_dict(), 'models/final_model.pt')
    writer.add_scalar("mape/best_val", best_val_mape)
    return writer
def train_fold():
    """Fine-tune a pretrained ``NucleicTransformer`` on one cross-validation fold.

    Reads options with ``get_args`` and restricts the visible GPUs to
    ``opts.gpu_id``. Loads ``train.json`` from ``opts.path``, keeping rows
    whose ``signal_to_noise`` exceeds ``opts.noise_filter``, and splits them
    into train/validation sets for fold ``opts.fold``. Per-sample loss weights
    are ``opts.error_alpha + exp(-errors * opts.error_beta)``.

    The model is wrapped in ``nn.DataParallel`` and initialised from the
    pretraining checkpoint of the last epoch listed in ``pretrain.csv``. It
    is then trained for ``opts.epochs`` epochs. The LR scheduler is not
    stepped until epoch ``cos_epoch``; after that, a ``CosineAnnealingLR``
    is stepped once per batch. In that cosine phase, validation results are
    appended to ``log_fold<fold>.csv`` every ``opts.val_freq`` epochs.
    Weights are saved with ``save_weights`` every ``opts.save_freq`` epochs,
    and ``get_best_weights_from_fold`` is called at the end.
    """
    opts = get_args()

    # GPU selection.
    os.environ["CUDA_VISIBLE_DEVICES"] = opts.gpu_id
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    # Load data and drop low signal-to-noise samples. The name avoids
    # shadowing the stdlib ``json`` module.
    json_path = os.path.join(opts.path, 'train.json')
    train_df = pd.read_json(json_path, lines=True)
    train_df = train_df[train_df.signal_to_noise > opts.noise_filter]
    ids = np.asarray(train_df.id.to_list())
    error_weights = get_errors(train_df)
    error_weights = opts.error_alpha + np.exp(-error_weights * opts.error_beta)
    train_indices, val_indices = get_train_val_indices(train_df, opts.fold, SEED=2020, nfolds=opts.nfolds)
    _, labels = get_data(train_df)
    sequences = np.asarray(train_df.sequence)
    train_seqs = sequences[train_indices]
    val_seqs = sequences[val_indices]
    train_labels = labels[train_indices]
    val_labels = labels[val_indices]
    train_ids = ids[train_indices]
    val_ids = ids[val_indices]
    train_ew = error_weights[train_indices]
    val_ew = error_weights[val_indices]

    dataset = RNADataset(train_seqs, train_labels, train_ids, train_ew, opts.path)
    val_dataset = RNADataset(val_seqs, val_labels, val_ids, val_ew, opts.path, training=False)
    dataloader = DataLoader(dataset, batch_size=opts.batch_size, shuffle=True,
                            num_workers=opts.workers)
    val_dataloader = DataLoader(val_dataset, batch_size=opts.batch_size * 2, shuffle=False,
                                num_workers=opts.workers)

    # Checkpointing and CSV logging.
    checkpoints_folder = 'checkpoints_fold{}'.format(opts.fold)
    csv_file = 'log_fold{}.csv'.format(opts.fold)
    columns = ['epoch', 'train_loss', 'val_loss']
    logger = CSVLogger(columns, csv_file)

    # Build the model, optimizer and loss.
    model = NucleicTransformer(opts.ntoken, opts.nclass, opts.ninp, opts.nhead, opts.nhid,
                               opts.nlayers, opts.kmer_aggregation, kmers=opts.kmers,
                               stride=opts.stride, dropout=opts.dropout).to(device)
    optimizer = Ranger(model.parameters(), weight_decay=opts.weight_decay)
    criterion = weighted_MCRMSE

    # Wrap before loading: the pretraining weights are loaded into the
    # DataParallel wrapper, so their keys are expected to match it.
    model = nn.DataParallel(model)
    pretrained_df = pd.read_csv('pretrain.csv')
    model.load_state_dict(torch.load(
        'pretrain_weights/epoch{}.ckpt'.format(int(pretrained_df.iloc[-1].epoch))))

    pytorch_total_params = sum(p.numel() for p in model.parameters())
    print('Total number of paramters: {}'.format(pytorch_total_params))

    # Training loop: the scheduler is idle until ``cos_epoch``, then performs a
    # single cosine decay spread over every remaining batch.
    cos_epoch = int(opts.epochs * 0.75) - 1
    lr_schedule = torch.optim.lr_scheduler.CosineAnnealingLR(
        optimizer, (opts.epochs - cos_epoch) * len(dataloader))
    for epoch in range(opts.epochs):
        model.train(True)
        t = time.time()
        total_loss = 0.0
        optimizer.zero_grad()
        step = 0
        for data in dataloader:
            step += 1  # 1-based index of the current batch
            lr = get_lr(optimizer)
            src = data['data'].to(device)
            labels = data['labels'].to(device)
            bpps = data['bpp'].to(device)
            output = model(src, bpps)
            ew = data['ew'].to(device)
            # Only the first 68 entries along dim 1 are scored (presumably the
            # labelled sequence positions; confirm against the dataset).
            loss = criterion(output[:, :68], labels, ew).mean()
            loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), 1)
            optimizer.step()
            optimizer.zero_grad()
            # Accumulate a float so no autograd history is kept across batches.
            total_loss += loss.item()
            print("Epoch [{}/{}], Step [{}/{}] Loss: {:.3f} Lr:{:.6f} Time: {:.1f}"
                  .format(epoch + 1, opts.epochs, step, len(dataloader),
                          total_loss / step, lr, time.time() - t),
                  end='\r', flush=True)
            if epoch > cos_epoch:
                lr_schedule.step()
        print('')
        # ``step`` already counts the batches; max(..., 1) guards an empty loader.
        train_loss = total_loss / max(step, 1)
        torch.cuda.empty_cache()
        if (epoch + 1) % opts.val_freq == 0 and epoch > cos_epoch:
            val_loss = validate(model, device, val_dataloader, batch_size=opts.batch_size)
            to_log = [epoch + 1, train_loss, val_loss]
            logger.log(to_log)
        if (epoch + 1) % opts.save_freq == 0:
            save_weights(model, optimizer, epoch, checkpoints_folder)

    get_best_weights_from_fold(opts.fold)
def train(fold_idx=None):
    """Train a one-class ``DeepLabV3_plus`` segmenter and return its best validation mIoU.

    Optimizes ``BCEWithLogitsLoss`` with Ranger (lr 1e-3, weight decay 5e-4)
    and a cosine scheduler (``T_max=10``) stepped once per epoch. After every
    epoch the model is scored with ``eval_net_unet_miou``. A strictly better
    score saves the weights under ``config.dir_weight``. Training stops once
    more than ``config.num_patience_epoch`` epochs pass without improvement.

    Relies on module-level ``config``, ``device``, ``accuracy``,
    ``get_trainval_dataloader`` and ``eval_net_unet_miou``.

    Args:
        fold_idx: zero-based fold index or ``None``; when given, the save path
            gets a ``_fold<idx>`` suffix.

    Returns:
        The best validation score seen (0 if no epoch improved on it).
    """
    # The model must live on the same device as the batches moved below.
    model = DeepLabV3_plus(num_classes=1, backbone='resnet', sync_bn=True).to(device)
    train_dataloader, valid_dataloader = get_trainval_dataloader()
    criterion = nn.BCEWithLogitsLoss()
    optimizer = Ranger(model.parameters(), lr=1e-3, weight_decay=0.0005)
    scheduler_cosine = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=10)
    best_val_score = 0
    last_improved_epoch = 0
    if fold_idx is None:
        print('start')
        model_save_path = os.path.join(config.dir_weight, '{}.bin'.format(config.save_model_name))
    else:
        print('start fold: {}'.format(fold_idx + 1))
        model_save_path = os.path.join(
            config.dir_weight, '{}_fold{}.bin'.format(config.save_model_name, fold_idx))

    for cur_epoch in range(config.num_epochs):
        start_time = int(time.time())
        model.train()
        print('epoch: ', cur_epoch + 1)
        cur_step = 0
        for batch in train_dataloader:
            # NOTE(review): BCEWithLogitsLoss expects float masks shaped like
            # the logits; confirm against the dataset.
            batch_x = batch['image'].to(device)
            batch_y = batch['mask'].to(device)
            optimizer.zero_grad()
            mask_pred = model(batch_x)
            train_loss = criterion(mask_pred, batch_y)
            train_loss.backward()
            optimizer.step()
            cur_step += 1
            if cur_step % config.step_train_print == 0:
                train_acc = accuracy(mask_pred, batch_y)
                msg = 'the current step: {0}/{1}, train loss: {2:>5.2}, train acc: {3:>6.2%}'
                print(msg.format(cur_step, len(train_dataloader), train_loss.item(),
                                 train_acc[0].item()))

        val_score = eval_net_unet_miou(model, valid_dataloader, device)
        if val_score > best_val_score:
            best_val_score = val_score
            torch.save(model.state_dict(), model_save_path)
            improved_str = '*'
            last_improved_epoch = cur_epoch
        else:
            improved_str = ''
        # Placeholders index the four arguments passed below (0..4); the old
        # string referenced {5} and raised IndexError.
        msg = 'the current epoch: {0}/{1}, val score: {2:>6.2%}, cost: {3}s {4}'
        end_time = int(time.time())
        print(msg.format(cur_epoch + 1, config.num_epochs, val_score,
                         end_time - start_time, improved_str))
        if cur_epoch - last_improved_epoch > config.num_patience_epoch:
            print("No optimization for a long time, auto-stopping...")
            break
        scheduler_cosine.step()

    del model
    gc.collect()
    return best_val_score