def train_model(self, data_loader, criterion, optimizer):
    device = 'cuda:0' if torch.cuda.is_available() else 'cpu'
    self.train()
    self.to(device)
    loss_avg = AverageMeter()
    acc_avg = AverageMeter()
    loss_avg.reset()
    acc_avg.reset()
    for label, image in tqdm(data_loader):
        image: torch.Tensor = image.to(device)
        label: torch.Tensor = label.to(device).float().unsqueeze(dim=1)
        pred: torch.Tensor = self.forward(image)
        loss: torch.Tensor = criterion(pred, label)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        loss_avg.update(loss.item())
        # Round sigmoid outputs to {0, 1}: adding 0.5 and truncating to long
        # is equivalent to rounding for values in [0, 1).
        pred = (pred + 0.5).long()
        num_correct = (pred == label).sum().item()
        acc_avg.update(num_correct / image.shape[0])
    return loss_avg.avg, acc_avg.avg
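# Every routine in this file relies on an AverageMeter with reset()/update()
# and a .avg attribute. A minimal sketch, assuming the conventional pattern
# from the PyTorch ImageNet example (not taken from this source):
class AverageMeter:
    """Tracks a running weighted sum and count, exposing the mean as .avg."""

    def __init__(self):
        self.reset()

    def reset(self):
        self.val = 0.0   # most recent value
        self.sum = 0.0   # weighted sum of values
        self.count = 0   # total weight
        self.avg = 0.0   # running mean

    def update(self, val, n=1):
        self.val = val
        self.sum += val * n
        self.count += n
        self.avg = self.sum / self.count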
def validate(self, epoch):
    self.model.eval()
    losses = AverageMeter()
    times = AverageMeter()
    losses.reset()
    times.reset()
    len_d = len(self.valid_loader)
    end = time.time()
    with torch.no_grad():
        for i, data in enumerate(self.valid_loader):
            input, label = data
            input = [ele.to(self.device) for ele in input]
            label = [ele.to(self.device) for ele in label]
            output = self.model(input)
            bat_val_loss = self.loss_fn(output, label)
            bat_val_loss_avg = torch.mean(bat_val_loss)
            losses.update(bat_val_loss_avg.item())
            times.update(time.time() - end)
            end = time.time()
            writer.add_scalar('valid_loss/batch_loss', bat_val_loss_avg,
                              epoch * len_d + i)
            print('epoch %d, %d/%d, validation loss: %f, time estimated: %.2f seconds'
                  % (epoch, i + 1, len_d, bat_val_loss_avg, times.avg * len_d),
                  end='\r')
    print("\n")
    writer.add_scalar('valid_loss/valid_loss', losses.avg, epoch)
    if losses.avg < self.min_loss:
        self.early_stop_count = 0
        self.min_loss = losses.avg
        torch.save(self.model.state_dict(),
                   self.output_path + "/model.epoch%d" % epoch)
        print("Saved new model")
    else:
        self.early_stop_count += 1
def train(self, epoch):
    losses = AverageMeter()
    times = AverageMeter()
    losses.reset()
    times.reset()
    self.model.train()
    len_d = len(self.train_loader)
    end = time.time()
    for i, data in enumerate(self.train_loader):
        input, label = data
        input = [ele.to(self.device) for ele in input]
        label = [ele.to(self.device) for ele in label]
        output = self.model(input)
        bat_loss = self.loss_fn(output, label)
        bat_loss_avg = torch.mean(bat_loss)
        losses.update(bat_loss_avg.item())
        self.optimizer.zero_grad()
        bat_loss_avg.backward()
        torch.nn.utils.clip_grad_norm_(self.model.parameters(), 1)
        self.optimizer.step()
        times.update(time.time() - end)
        end = time.time()
        writer.add_scalar('train_loss/batch_loss', bat_loss_avg,
                          epoch * len_d + i)
        print('epoch %d, %d/%d, training loss: %f, time estimated: %.2f seconds'
              % (epoch, i + 1, len_d, bat_loss_avg, times.avg * len_d),
              end='\r')
    self.scheduler.step()
    print("\n")
    writer.add_scalar('train_loss/train_loss', losses.avg, epoch)
def train(self, epoch):
    losses = AverageMeter()
    times = AverageMeter()
    losses.reset()
    times.reset()
    self.model.train()
    len_d = len(self.train_loader)
    init_time = time.time()
    end = init_time
    for i, data in enumerate(self.train_loader):
        input, label = data
        output = self.model(input)
        loss = self.loss_fn(output, label)
        loss_avg = torch.mean(loss)
        losses.update(loss_avg.item())
        self.optimizer.zero_grad()
        loss_avg.backward()
        torch.nn.utils.clip_grad_norm_(self.model.parameters(), 1)
        self.optimizer.step()
        times.update(time.time() - end)
        end = time.time()
        print('epoch %d, %d/%d, training loss: %f, time estimated: %.2f/%.2f seconds'
              % (epoch, i + 1, len_d, losses.avg, end - init_time,
                 times.avg * len_d),
              end='\r')
    print("\n")
def test(self):
    times = AverageMeter()
    times.reset()
    len_d = len(self.test_loader)
    end = time.time()
    with torch.no_grad():
        for i, data in enumerate(self.test_loader):
            input, infdat = data
            input = [ele.to(self.device) for ele in input]
            output = self.model(input)
            audio_out = output[0].squeeze().cpu().detach().numpy()
            out_path = self.output_path + '/out'
            if not os.path.exists(out_path):
                os.makedirs(out_path)
            fn = out_path + '/' + infdat[0][0]
            audio_out = audio_out.squeeze().T
            sf.write(fn, audio_out, self.feature_options.sampling_rate,
                     subtype='PCM_16')
            times.update(time.time() - end)
            end = time.time()
            print('%d/%d, time estimated: %.2f seconds'
                  % (i + 1, len_d, times.avg * len_d), end='\r')
    print("\n")
def _train_one_epoch(self, epoch):
    self.model.train()
    loss_meter = AverageMeter()
    time_meter = TimeMeter()
    for bid, (video, video_mask, words, word_mask, label, scores,
              scores_mask, id2pos, node_mask, adj_mat) in enumerate(self.train_loader, 1):
        self.optimizer.zero_grad()
        model_input = {
            'frames': video.cuda(),
            'frame_mask': video_mask.cuda(),
            'words': words.cuda(),
            'word_mask': word_mask.cuda(),
            'label': scores.cuda(),
            'label_mask': scores_mask.cuda(),
            'gt': label.cuda(),
            'node_pos': id2pos.cuda(),
            'node_mask': node_mask.cuda(),
            'adj_mat': adj_mat.cuda()
        }
        predict_boxes, loss, _, _, _ = self.model(**model_input)
        loss = torch.mean(loss)
        self.optimizer.backward(loss)
        self.optimizer.step()
        self.num_updates += 1
        curr_lr = self.lr_scheduler.step_update(self.num_updates)
        loss_meter.update(loss.item())
        time_meter.update()
        if bid % self.args.display_n_batches == 0:
            logging.info('Epoch %d, Batch %d, loss = %.4f, lr = %.5f, %.3f seconds/batch'
                         % (epoch, bid, loss_meter.avg, curr_lr,
                            1.0 / time_meter.avg))
            loss_meter.reset()
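# _train_one_epoch() assumes a TimeMeter whose .avg is a rate (updates per
# second), since the log line converts it with 1.0 / time_meter.avg to get
# seconds/batch. A minimal sketch under that assumption (not from the source):
import time

class TimeMeter:
    """Measures the average rate of update() calls per second."""

    def __init__(self):
        self.reset()

    def reset(self):
        self.start = time.time()
        self.n = 0

    def update(self):
        self.n += 1

    @property
    def avg(self):
        elapsed = max(time.time() - self.start, 1e-12)  # avoid divide-by-zero
        return self.n / elapsed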
def val(args, model=None, current_epoch=0):
    top1 = AverageMeter()
    top5 = AverageMeter()
    top1.reset()
    top5.reset()
    if model is None:
        model = get_model(args)
    model.eval()
    _, val_loader = data_loader(args, test_path=True)
    save_atten = SAVE_ATTEN(save_dir=args.save_atten_dir)

    global_counter = 0
    prob = None
    gt = None
    for idx, dat in tqdm(enumerate(val_loader)):
        img_path, img, label_in = dat
        global_counter += 1
        if args.tencrop == 'True':
            bs, ncrops, c, h, w = img.size()
            img = img.view(-1, c, h, w)
            label_input = label_in.repeat(10, 1)
            label = label_input.view(-1)
        else:
            label = label_in

        img, label = img.cuda(), label.cuda()
        img_var, label_var = Variable(img), Variable(label)

        logits = model(img_var, label_var)
        logits0 = logits[0]
        logits0 = F.softmax(logits0, dim=1)
        if args.tencrop == 'True':
            # average the predictions over the ten crops
            logits0 = logits0.view(bs, ncrops, -1).mean(1)

        # Calculate classification results
        prec1_1, prec5_1 = Metrics.accuracy(logits0.cpu().data,
                                            label_in.long(), topk=(1, 5))
        # prec3_1, prec5_1 = Metrics.accuracy(logits[1].data, label.long(), topk=(1,5))
        top1.update(prec1_1[0], img.size()[0])
        top5.update(prec5_1[0], img.size()[0])

        # save_atten.save_heatmap_segmentation(img_path, np_last_featmaps, label.cpu().numpy(),
        #                                      save_dir='./save_bins/heatmaps', size=(0,0), maskedimg=True)

        # np_last_featmaps = logits[2].cpu().data.numpy()
        np_last_featmaps = logits[-1].cpu().data.numpy()
        np_scores, pred_labels = torch.topk(logits0, k=args.num_classes, dim=1)
        pred_np_labels = pred_labels.cpu().data.numpy()
        save_atten.save_top_5_pred_labels(pred_np_labels[:, :5], img_path,
                                          global_counter)
        # pred_np_labels[:,0] = label.cpu().numpy()  # replace the first label with gt label
        # save_atten.save_top_5_atten_maps(np_last_featmaps, pred_np_labels, img_path)

    print('Top1:', top1.avg, 'Top5:', top5.avg)
def validate(self, epoch):
    self.model.eval()
    losses = AverageMeter()
    times = AverageMeter()
    losses.reset()
    times.reset()
    len_d = len(self.valid_loader)
    end = time.time()
    with torch.no_grad():  # the original omitted this; no gradients are needed in validation
        for i, data in enumerate(self.valid_loader):
            input, label = data
            if torch.sum(label[0]) < 1:
                continue
            output = self.model(input)
            loss = self.loss_fn(output, label)
            loss_avg = torch.mean(loss)
            losses.update(loss_avg.item())
            times.update(time.time() - end)
            end = time.time()
            print('epoch %d, %d/%d, validation loss: %f, time estimated: %.2f seconds'
                  % (epoch, i + 1, len_d, losses.avg, times.avg * len_d),
                  end='\r')
    print("\n")
    if losses.avg < self.min_loss:
        self.early_stop_count = 0
        self.min_loss = losses.avg
        torch.save(self.model, self.output_path + "/model.epoch%d" % epoch)
        print("Saved new model")
    else:
        self.early_stop_count += 1
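# validate() above only bumps self.early_stop_count; a hypothetical driver
# loop (trainer, patience, and max_epochs are assumed names, not shown in
# the source) would consume it like this:
for epoch in range(max_epochs):
    trainer.train(epoch)
    trainer.validate(epoch)
    if trainer.early_stop_count >= patience:  # e.g. patience = 5
        print("Early stopping at epoch %d" % epoch)
        break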
def train(epoch):
    net.train()
    total_loss = AverageMeter()
    epoch_loss_stats = AverageMeter()
    start_time = time.time()
    bar = tqdm(enumerate(train_loader))
    for batch_idx, (inputs, labels) in bar:
        inputs = Variable(inputs)
        labels = Variable(labels)
        if args.cuda:
            inputs = inputs.cuda()
            labels = labels.cuda()
        optimizer.zero_grad()
        outputs = net(inputs)
        outputs = F.log_softmax(outputs, dim=1)  # explicit dim avoids the deprecation warning
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()
        # .item() replaces the PyTorch <= 0.3 loss.data[0] idiom, which fails
        # on 0-dim tensors in modern versions
        total_loss.update(loss.item(), inputs.size(0))
        epoch_loss_stats.update(loss.item(), inputs.size(0))

        if args.visdom is not None:
            cur_iter = batch_idx + (epoch - 1) * len(train_loader)
            vis.plot_line('train_plt',
                          X=torch.ones((1, )).cpu() * cur_iter,
                          Y=loss.data.cpu(),
                          update='append')

        if batch_idx % args.backup_iters == 0:
            filename = 'texture_{0}_snapshot.pth'.format(args.split)
            filename = osp.join(args.save_folder, filename)
            state_dict = net.state_dict()
            torch.save(state_dict, filename)

            optim_filename = 'texture_{0}_optim.pth'.format(args.split)
            optim_filename = osp.join(args.save_folder, optim_filename)
            state_dict = optimizer.state_dict()
            torch.save(state_dict, optim_filename)

        if batch_idx % args.log_interval == 0:
            elapsed_time = time.time() - start_time
            bar.set_description('[{:5d}] ({:5d}/{:5d}) | ms/batch {:.6f} |'
                                ' loss {:.6f} | lr {:.7f}'.format(
                                    epoch, batch_idx, len(train_loader),
                                    elapsed_time * 1000, total_loss.avg,
                                    optimizer.param_groups[0]['lr']))
            total_loss.reset()
            start_time = time.time()

    epoch_total_loss = epoch_loss_stats.avg
    if args.visdom is not None:
        vis.plot_line('train_epoch_plt',
                      X=torch.ones((1, )).cpu() * epoch,
                      Y=torch.ones((1, )).cpu() * epoch_total_loss,
                      update='append')
    return epoch_total_loss
def train(epoch, relative_age=True):
    net.train()
    total_loss = AverageMeter()
    epoch_loss_stats = AverageMeter()
    time_stats = AverageMeter()
    loss = 0
    optimizer.zero_grad()
    for (batch_idx, (imgs, bone_ages, genders,
                     chronological_ages, _)) in enumerate(train_loader):
        imgs = imgs.to(device)
        bone_ages = bone_ages.to(device)
        genders = genders.to(device)
        chronological_ages = chronological_ages.to(device)
        if relative_age:
            relative_ages = chronological_ages.squeeze(1) - bone_ages
        start_time = time.time()
        outputs = net(imgs, genders, chronological_ages)
        if relative_age:
            loss = criterion(outputs.squeeze(), relative_ages)
        else:
            loss = criterion(outputs.squeeze(), bone_ages)
        loss.backward()
        optimizer.step()
        loss = metric_average(loss.item(), 'loss')
        time_stats.update(time.time() - start_time, 1)
        total_loss.update(loss, 1)
        epoch_loss_stats.update(loss, 1)
        optimizer.zero_grad()
        if (batch_idx % args.log_interval == 0) and args.rank == 0:
            elapsed_time = time_stats.avg
            print(' [{:5d}] ({:5d}/{:5d}) | ms/batch {:.4f} |'
                  ' loss {:.6f} | avg loss {:.6f} | lr {:.7f}'.format(
                      epoch, batch_idx, len(train_loader),
                      elapsed_time * 1000, total_loss.avg,
                      epoch_loss_stats.avg,
                      optimizer.param_groups[0]['lr']))
            total_loss.reset()

    epoch_total_loss = epoch_loss_stats.avg
    args.resume_iter = 0
    if args.rank == 0:
        filename = 'boneage_bonet_snapshot.pth'
        filename = osp.join(args.save_folder, filename)
        torch.save(net.state_dict(), filename)

        optim_filename = 'boneage_bonet_optim.pth'
        optim_filename = osp.join(args.save_folder, optim_filename)
        torch.save(optimizer.state_dict(), optim_filename)
    return epoch_total_loss
def train(epoch):
    net.train()
    total_loss = AverageMeter()
    epoch_loss_stats = AverageMeter()
    start_time = time.time()
    for i_batch, sample_batched in enumerate(train_loader):
        im = Variable(sample_batched[0])
        label = Variable(sample_batched[1])
        if args.cuda:
            im = im.cuda()
            label = label.cuda()
        optimizer.zero_grad()
        out_masks = net(im, label)
        out_masks = out_masks.cuda()
        loss = criterion(out_masks, label)
        loss.backward()
        optimizer.step()
        # .item() replaces the PyTorch <= 0.3 loss.data[0] idiom
        total_loss.update(loss.item(), im.size(0))
        epoch_loss_stats.update(loss.item(), im.size(0))

        if i_batch % args.backup_iters == 0:
            filename = 'casenet_{0}_{1}_snapshot.pth'.format(args.dataset, args.split)
            filename = osp.join(args.save_folder, filename)
            state_dict = net.state_dict()
            torch.save(state_dict, filename)

            optim_filename = 'casenet_{0}_{1}_optim.pth'.format(args.dataset, args.split)
            optim_filename = osp.join(args.save_folder, optim_filename)
            state_dict = optimizer.state_dict()
            torch.save(state_dict, optim_filename)

        if i_batch % args.log_interval == 0:
            elapsed_time = time.time() - start_time
            print('[{:5d}] ({:5d}/{:5d}) | ms/batch {:.6f} |'
                  ' loss {:.6f} | lr {:.7f}'.format(
                      epoch, i_batch, len(train_loader),
                      elapsed_time * 1000, total_loss.avg,
                      scheduler.get_lr()[0]))
            total_loss.reset()
            start_time = time.time()

    epoch_total_loss = epoch_loss_stats.avg
    return epoch_total_loss
def test(model, test_loader, epoch, margin, threshlod,
         is_cuda=True, log_interval=1000):
    model.eval()
    test_loss = AverageMeter()
    accuracy = 0
    num_p = 0
    total_num = 0
    batch_num = len(test_loader)
    # volatile=True was the PyTorch <= 0.3 idiom; torch.no_grad() replaces it
    with torch.no_grad():
        for batch_idx, (data_a, data_p, data_n, target) in enumerate(test_loader):
            if is_cuda:
                data_a = data_a.cuda()
                data_p = data_p.cuda()
                data_n = data_n.cuda()
                target = target.cuda()

            out_a = model(data_a)
            out_p = model(data_p)
            out_n = model(data_n)

            loss = F.triplet_margin_loss(out_a, out_p, out_n, margin)

            dist1 = F.pairwise_distance(out_a, out_p)
            dist2 = F.pairwise_distance(out_a, out_n)
            # a positive pair counts as correct if closer than the threshold,
            # a negative pair if farther
            num = ((dist1 < threshlod).sum() + (dist2 > threshlod).sum()).item()
            num_p += num
            num_p = 1.0 * num_p
            total_num += data_a.size()[0] * 2
            test_loss.update(loss.item())
            if (batch_idx + 1) % log_interval == 0:
                accuracy_tmp = num_p / total_num
                print('Test- Epoch {:04d}\tbatch:{:06d}/{:06d}\tAccuracy:{:.04f}\tloss:{:06f}'
                      .format(epoch, batch_idx + 1, batch_num,
                              accuracy_tmp, test_loss.avg))
                test_loss.reset()

    accuracy = num_p / total_num
    return accuracy
def run(self, epochs, lr_decay=False, mixup=False):
    train_loss_meter = AverageMeter()
    min_val_loss = 1e8
    tolerance = self.early_stopping_tolerance
    self.scheduler_setup(lr_decay)
    for epoch in range(1, epochs + 1):
        # train() returns the loss per batch
        losses = train(self.model, self.device, self.train_loader,
                       self.optimizer, self.criterion, epoch, mixup,
                       train_loss_meter)
        self.train_loss.extend(losses)

        # test() returns loss/accuracy per epoch
        val_loss, val_accuracy = test(self.model, self.device,
                                      self.val_loader, self.criterion)
        self.test_loss.append(val_loss)
        self.test_accuracy.append(val_accuracy)

        self.scheduler_step(lr_decay, epoch)

        print(f"Epoch {epoch} \t"
              f"train_loss: {train_loss_meter.average:.6f}"
              f"\tval_loss: {val_loss:.4f}\tval_accuracy: {val_accuracy * 100:.2f}%")
        train_loss_meter.reset()

        if val_loss < min_val_loss:
            min_val_loss = val_loss
            # Reset the tolerance because validation loss improved
            tolerance = self.early_stopping_tolerance
        else:
            if epoch > 20:  # Early stopping doesn't kick in before 20 epochs
                tolerance -= 1
        if tolerance == 0:  # Early stopping: end the training process
            print(f"\nEarly stopping. Val loss did not improve for "
                  f"{self.early_stopping_tolerance} consecutive epochs")
            break
def do_test(cfg, model, test_loader, experiment_name):
    test_acc_meter = AverageMeter()
    test_acc_meter.reset()
    device = cfg.MODEL.DEVICE

    logger = logging.getLogger('{}.test'.format(cfg.PROJECT.NAME))
    logger.info("Enter Image Classification Test")

    if device:
        if torch.cuda.device_count() > 1:
            print('Using {} GPUs for test'.format(torch.cuda.device_count()))
            model = nn.DataParallel(model)
        model.to(device)

    # generate result csv
    output_dir = os.path.join(cfg.MODEL.OUTPUT_PATH, experiment_name)
    result_path = os.path.join(output_dir,
                               experiment_name + '_' + 'test_result.csv')
    with open(result_path, 'w') as f:
        f.write("file_name,label,predictive_label")

    model.eval()
    for iteration, (img, vid, vname) in enumerate(test_loader):
        with torch.no_grad():
            img = img.to(device)
            vid = torch.tensor(vid)
            target = vid.to(device)
            score = model(img)
            p_label = score.max(1)[1]
            acc = (score.max(1)[1] == target).float().mean()
            test_acc_meter.update(acc, 1)
            logger.info("Iteration[{}/{}], Test_Acc: {:.3f}".format(
                iteration + 1, len(test_loader), test_acc_meter.avg))
            with open(result_path, 'a+') as f:
                for i in range(len(vid)):
                    name = list(vname)[i]
                    label = str(vid[i].item())
                    p_label_ = str(p_label[i].item())
                    f.write('\n')
                    f.write(name + ',' + label + ',' + p_label_)
def process_epoch(self, train):
    if train:
        data_loader = self.train_loader
    else:
        data_loader = self.val_loader

    loss_unatt_agg = AverageMeter()
    acc_unatt_agg = AverageMeter()
    loss_att_agg = AverageMeter()
    acc_att_agg = AverageMeter()

    for i, (input, target) in enumerate(data_loader):
        if self.target:
            target = self.target * target.new_ones(target.size())
        input, target = input.to(self.device), target.to(self.device)

        with torch.no_grad():
            output = self.classifier(input)
            los_unatt = self.criterion(output, target)
        loss_unatt_agg.update(los_unatt.item(), input.size(0))
        acc_unatt_agg.update(accuracy(output, target).item(), input.size(0))

        with torch.set_grad_enabled(train):
            input_att, _ = self.framing(input=input)
            output_att = self.classifier(input_att)
            loss_att = self.criterion(output_att, target)
        loss_att_agg.update(loss_att.item(), input.size(0))
        acc_att_agg.update(accuracy(output_att, target).item(), input.size(0))

        if train:
            self.optimizer.zero_grad()
            # minimize the attacked loss for a targeted attack,
            # maximize it (minimize the negative) otherwise
            framing_loss = loss_att if self.target is not None else -loss_att
            framing_loss.backward()
            self.optimizer.step()
            self.step += input.size(0)

        if train:
            if (i + 1) % self.args.print_freq == 0:
                self.logger.log_kv([
                    ('unatt_loss', loss_unatt_agg.avg),
                    ('att_loss', loss_att_agg.avg),
                    ('unatt_acc', acc_unatt_agg.avg),
                    ('att_acc', acc_att_agg.avg),
                ], prefix='train', step=self.step, write_to_tb=True)
                loss_unatt_agg.reset()
                loss_att_agg.reset()
                acc_unatt_agg.reset()
                acc_att_agg.reset()
        else:
            if i + 1 == len(data_loader):
                self.logger.log_kv([
                    ('unatt_loss', loss_unatt_agg.avg),
                    ('att_loss', loss_att_agg.avg),
                    ('unatt_acc', acc_unatt_agg.avg),
                    ('att_acc', acc_att_agg.avg),
                ], prefix='eval', step=self.step, write_to_tb=True,
                    write_to_file=True)
def train(epoch): """Train of the net.""" net.train() total_loss = AverageMeter() epoch_loss_stats = AverageMeter() start_time = time.time() bar = tqdm(enumerate(train_loader)) for batch_idx, sample in bar: optimizer.zero_grad() # TODO: Call the train routine for the net # outputs, label = routines.train_routine(sample) loss = criterion(outputs, label) loss.backward() optimizer.step() total_loss.update(loss.data, n=outputs.size(0)) epoch_loss_stats.update(loss.data, n=outputs.size(0)) if batch_idx % cfg.TRAIN.BACKUP_ITERS == 0: filename = '{0}_snapshot.pth'.format(cfg.DATASET.SPLIT) filename = osp.join(cfg.OUTPUT_DIR, filename) state_dict = net.state_dict() torch.save(state_dict, filename) optim_filename = '{0}_optim.pth'.format(cfg.DATASET.SPLIT) optim_filename = osp.join(cfg.OUTPUT_DIR, optim_filename) state_dict = optimizer.state_dict() torch.save(state_dict, optim_filename) if batch_idx % cfg.TRAIN.LOG_INTERVAL == 0: elapsed_time = time.time() - start_time bar.set_description('[{:5d}] ({:5d}/{:5d}) | ms/batch {:.6f} |' ' loss {:.6f} | lr {:.7f}'.format( epoch, batch_idx, len(train_loader), elapsed_time * 1000, total_loss.avg, optimizer.param_groups[0]['lr'])) total_loss.reset() start_time = time.time() epoch_total_loss = epoch_loss_stats.avg return epoch_total_loss
def validate(self, epoch):
    self.model.eval()
    losses = AverageMeter()
    times = AverageMeter()
    losses.reset()
    times.reset()
    len_d = len(self.valid_loader)
    init_time = time.time()
    end = init_time
    with torch.no_grad():  # the original omitted this; no gradients are needed in validation
        for i, data in enumerate(self.valid_loader):
            input, label = data
            if torch.sum(label[0]) < 1:
                continue
            output = self.model(input)
            loss = self.loss_fn(output, label)
            loss_avg = torch.mean(loss)
            losses.update(loss_avg.item())
            times.update(time.time() - end)
            end = time.time()
            print('epoch %d, %d/%d, validation loss: %f, time estimated: %.2f/%.2f seconds'
                  % (epoch, i + 1, len_d, losses.avg, end - init_time,
                     times.avg * len_d),
                  end='\r')
    print("\n")
    if losses.avg < self.min_loss:
        self.early_stop_count = 0
        self.min_loss = losses.avg
        saved_dict = {
            'model': self.model.state_dict(),
            'epoch': epoch,
            'optimizer': self.optimizer,
            'cv_loss': self.min_loss,
            "early_stop_count": self.early_stop_count
        }
        torch.save(saved_dict, self.output_path + "/final.mdl")
        print("Saved new model")
    else:
        self.early_stop_count += 1
def validate_model(self, data_loader, criterion):
    with torch.no_grad():
        device = 'cuda:0' if torch.cuda.is_available() else 'cpu'
        self.eval()
        self.to(device)
        loss_avg = AverageMeter()
        acc_avg = AverageMeter()
        loss_avg.reset()
        acc_avg.reset()
        for label, image in tqdm(data_loader):
            image: torch.Tensor = image.to(device)
            label: torch.Tensor = label.to(device).float().unsqueeze(dim=1)
            pred: torch.Tensor = self.forward(image)
            loss: torch.Tensor = criterion(pred, label)
            pred = pred >= 0.5  # threshold sigmoid outputs at 0.5
            loss_avg.update(loss.item())
            num_correct = (pred == label).sum().item()
            acc_avg.update(num_correct / image.shape[0])
        return loss_avg.avg, acc_avg.avg
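# Hypothetical driver for the train_model/validate_model pair above; the model
# class, loader names, and hyperparameters are assumed, not from the source:
model = MyBinaryClassifier()
criterion = torch.nn.BCELoss()
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
for epoch in range(10):
    train_loss, train_acc = model.train_model(train_loader, criterion, optimizer)
    val_loss, val_acc = model.validate_model(val_loader, criterion)
    print(f'epoch {epoch}: train {train_loss:.4f}/{train_acc:.3f} '
          f'val {val_loss:.4f}/{val_acc:.3f}')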
class Trainer():
    def __init__(self, disp_model, pose_model, optimizer, opt):
        self.disp_model = disp_model
        self.pose_model = pose_model
        self.optimizer = optimizer
        self.batch_time = AverageMeter()
        self.data_time = AverageMeter()
        self.losses = AverageMeter()

    def train(self, trainloader, epoch, opt):
        self.losses.reset()
        self.data_time.reset()
        self.batch_time.reset()
        end = time.time()
        self.disp_model.train()
        self.pose_model.train()
        for i, data in enumerate(trainloader, 0):
            self.optimizer.zero_grad()
            if opt.cuda:
                target_imgs, ref_imgs, intrinsics, intrinsics_inv = data
                # non_blocking= replaces the pre-0.4 async= keyword, which is
                # a syntax error on Python >= 3.7
                target_imgs = Variable(target_imgs.cuda(non_blocking=True))
                ref_imgs = [Variable(img.cuda(non_blocking=True)) for img in ref_imgs]
                intrinsics = Variable(intrinsics.cuda(non_blocking=True))
                intrinsics_inv = Variable(intrinsics_inv.cuda(non_blocking=True))
            self.data_time.update(time.time() - end)

            disparities = self.disp_model(target_imgs)
            depths = [1 / disp for disp in disparities]
            explainability_mask, pose = self.pose_model(target_imgs, ref_imgs)

            photoloss = photometric_reconstruction_loss(
                target_imgs, ref_imgs, intrinsics, intrinsics_inv, depths,
                explainability_mask, pose, opt.rotation_mode, opt.padding_mode)
            exploss = explainability_loss(explainability_mask)
            smoothloss = smooth_loss(disparities)
            totalloss = (opt.photo_loss_weight * photoloss
                         + opt.mask_loss_weight * exploss
                         + opt.smooth_loss_weight * smoothloss)

            totalloss.backward()
            self.optimizer.step()

            inputs_size = intrinsics.size(0)
            # .item() replaces the PyTorch <= 0.3 totalloss.data[0] idiom
            self.losses.update(totalloss.item(), inputs_size)
            self.batch_time.update(time.time() - end)
            end = time.time()

            if i % opt.printfreq == 0:
                print('Epoch: [{0}][{1}/{2}]\t'
                      'Time {batch_time.avg:.3f} ({batch_time.sum:.3f})\t'
                      'Data {data_time.avg:.3f} ({data_time.sum:.3f})\t'
                      'Loss {loss.avg:.3f}\t'.format(
                          epoch, i, len(trainloader),
                          batch_time=self.batch_time,
                          data_time=self.data_time,
                          loss=self.losses))
                sys.stdout.flush()
def validate(self, epoch):
    self.model.eval()
    losses = AverageMeter()
    times = AverageMeter()
    losses_snr = AverageMeter()
    losses.reset()
    times.reset()
    losses_snr.reset()
    len_d = len(self.valid_loader)
    end = time.time()
    with torch.no_grad():
        for i, data in enumerate(self.valid_loader):
            input, label = data
            input = [ele.to(self.device) for ele in input]
            label = [ele.to(self.device) for ele in label]
            out_spec, out_wav = self.model(input)
            loss_snr = self.loss_fn(out_wav, label)
            loss = -loss_snr  # maximize SNR by minimizing its negative
            loss_avg = torch.mean(loss)
            losses.update(loss_avg.item())
            losses_snr_avg = torch.mean(loss_snr)
            losses_snr.update(losses_snr_avg.item())
            times.update(time.time() - end)
            end = time.time()
            writer.add_scalar('valid_loss/loss(snr)', losses.avg,
                              epoch * len_d + i + 1)
            print('epoch %d, %d/%d, validation loss: %f, time estimated: %.2f seconds'
                  % (epoch, i + 1, len_d, losses.avg, times.avg * len_d),
                  end='\r')
    print("\n")
    if losses.avg < self.min_loss:
        self.early_stop_count = 0
        self.min_loss = losses.avg
        torch.save(self.model, self.output_path + "/model.epoch%d" % epoch)
        print("Saved new model")
    else:
        self.early_stop_count += 1
def train(model, num_epochs, resume_epoch):
    loss_meter = AverageMeter()
    rpn_meter = AverageMeter()
    frcnn_meter = AverageMeter()
    for epoch in tqdm(range(num_epochs), total=num_epochs):
        pbar = tqdm(voc_loader, total=len(voc_loader), leave=True)
        for image, target in pbar:
            # non_blocking= replaces the pre-0.4 async= keyword (a Py3 syntax error)
            image = Variable(image).cuda(non_blocking=True)
            target = target.squeeze(0).numpy()

            rpn_cls_probs, rpn_bbox_deltas, pred_label, pred_bbox_deltas = frcnn(image)
            proposal_boxes, _ = frcnn.get_rpn_proposals()
            if len(proposal_boxes) == 0:
                continue
            rpn_labels, rpn_bbox_targets, rpn_batch_indices = frcnn.get_rpn_targets(target)
            detector_labels, delta_boxes, clf_batch_indices = frcnn.get_detector_targets(target)

            rpn_loss = criterion1(rpn_cls_probs, rpn_bbox_deltas,
                                  Variable(rpn_labels, requires_grad=False).cuda(),
                                  Variable(rpn_bbox_targets, requires_grad=False).cuda(),
                                  rpn_batch_indices.cuda())
            frcnn_loss = criterion2(pred_label, pred_bbox_deltas,
                                    Variable(detector_labels, requires_grad=False).cuda(),
                                    Variable(delta_boxes, requires_grad=False).cuda(),
                                    clf_batch_indices.cuda())
            total_loss = rpn_loss + frcnn_loss

            # .item() replaces the PyTorch <= 0.3 .data[0] idiom
            rpn_meter.update(rpn_loss.item())
            frcnn_meter.update(frcnn_loss.item())
            loss_meter.update(total_loss.item())
            pbar.set_description(
                desc='loss {:.4f} | rpn loss {:.4f} | frcnn loss {:.4f}'.format(
                    loss_meter.avg, rpn_meter.avg, frcnn_meter.avg))

            # zero_grad was missing in the original; without it gradients
            # accumulate across batches
            optimizer.zero_grad()
            total_loss.backward()
            optimizer.step()

        if (epoch + 1) % CHECKPOINT_RATE == 0:
            save_checkpoint(frcnn.state_dict(), optimizer.state_dict(),
                            os.path.join(WEIGHT_DIR,
                                         "{}_{:.1e}_{:.4f}.pt".format(
                                             epoch + 1 + resume_epoch,
                                             LEARNING_RATE, loss_meter.avg)))
        loss_meter.reset()
        rpn_meter.reset()
        frcnn_meter.reset()
def train(model, img_encoder, normalize, base_loader, optimizer, criterion,
          epoch, total_epoch, device, logger, nodes, desc_embeddings,
          id_to_class_name, classFile_to_wikiID):
    batch_time = AverageMeter()  # forward prop. + back prop. time
    data_time = AverageMeter()   # data loading time
    losses = AverageMeter()      # loss

    model.train()
    img_encoder.eval()
    start = time.time()
    for i, (imgs, labels, sp_labels) in enumerate(base_loader):
        data_time.update(time.time() - start)
        imgs = imgs.to(device)
        labels = labels.to(device)
        sp_labels = sp_labels.to(device)
        corr_nodeIndexs = find_nodeIndex_by_imgLabels(nodes, labels,
                                                      id_to_class_name,
                                                      classFile_to_wikiID)

        _, class_outputs, sp_outputs, att_features, corr_features = model(
            imgs, desc_embeddings, corr_nodeIndexs, norm=normalize)
        loss = criterion(class_outputs, sp_outputs, labels, sp_labels,
                         att_features, corr_features)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        losses.update(loss.item())
        batch_time.update(time.time() - start)
        start = time.time()

        if i % 30 == 29:  # print every 30 mini-batches
            logger.info(
                f'[{epoch:3d}/{total_epoch}|{i+1:3d}, '
                f'{len(base_loader)}] batch_time: {batch_time.avg:.2f} '
                f'data_time: {data_time.avg:.2f} loss: {losses.avg:.3f}')
            batch_time.reset()
            data_time.reset()
            losses.reset()
def train(model, normalize, base_loader, optimizer, criterion, epoch,
          total_epoch, device, logger):
    batch_time = AverageMeter()  # forward prop. + back prop. time
    data_time = AverageMeter()   # data loading time
    losses = AverageMeter()      # loss

    model.train()
    start = time.time()
    for i, (imgs, labels, sp_labels) in enumerate(base_loader):
        data_time.update(time.time() - start)
        imgs = imgs.to(device)
        labels = labels.to(device)
        sp_labels = sp_labels.to(device)

        _, class_outputs, sp_outputs = model(imgs, norm=normalize)
        loss = criterion(class_outputs, sp_outputs, labels, sp_labels)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        losses.update(loss.item())
        batch_time.update(time.time() - start)
        start = time.time()

        if i % 30 == 29:  # print every 30 mini-batches
            logger.info(
                f'[{epoch:3d}/{total_epoch}|{i+1:3d}, '
                f'{len(base_loader)}] batch_time: {batch_time.avg:.2f} '
                f'data_time: {data_time.avg:.2f} loss: {losses.avg:.3f}')
            batch_time.reset()
            data_time.reset()
            losses.reset()
def train(args, net, optimizer, criterion, scheduler):
    log_file = open(args.save_root + "training.log", "w", 1)
    log_file.write(args.exp_name + '\n')
    for arg in vars(args):
        print(arg, getattr(args, arg))
        log_file.write(str(arg) + ': ' + str(getattr(args, arg)) + '\n')
    log_file.write(str(net))
    net.train()

    # loss counters
    batch_time = AverageMeter()
    losses = AverageMeter()
    loc_losses = AverageMeter()
    cls_losses = AverageMeter()

    print('Loading Dataset...')
    train_dataset = UCF24Detection(args.data_root, args.train_sets,
                                   SSDAugmentation(args.ssd_dim, args.means),
                                   AnnotationTransform(),
                                   input_type=args.input_type)
    val_dataset = UCF24Detection(args.data_root, 'test',
                                 BaseTransform(args.ssd_dim, args.means),
                                 AnnotationTransform(),
                                 input_type=args.input_type,
                                 full_test=False)
    epoch_size = len(train_dataset) // args.batch_size
    print('Training SSD on', train_dataset.name)

    if args.visdom:
        import visdom
        viz = visdom.Visdom()
        viz.port = 8097
        viz.env = args.exp_name
        # initialize visdom loss plot
        lot = viz.line(X=torch.zeros((1, )).cpu(),
                       Y=torch.zeros((1, 6)).cpu(),
                       opts=dict(xlabel='Iteration', ylabel='Loss',
                                 title='Current SSD Training Loss',
                                 legend=['REG', 'CLS', 'AVG',
                                         'S-REG', ' S-CLS', ' S-AVG']))
        # initialize visdom meanAP and class APs plot
        legends = ['meanAP']
        for cls in CLASSES:
            legends.append(cls)
        val_lot = viz.line(X=torch.zeros((1, )).cpu(),
                           Y=torch.zeros((1, args.num_classes)).cpu(),
                           opts=dict(xlabel='Iteration', ylabel='Mean AP',
                                     title='Current SSD Validation mean AP',
                                     legend=legends))

    batch_iterator = None
    train_data_loader = data.DataLoader(train_dataset, args.batch_size,
                                        num_workers=args.num_workers,
                                        shuffle=True,
                                        collate_fn=detection_collate,
                                        pin_memory=True)
    val_data_loader = data.DataLoader(val_dataset, args.batch_size,
                                      num_workers=args.num_workers,
                                      shuffle=False,
                                      collate_fn=detection_collate,
                                      pin_memory=True)
    itr_count = 0
    torch.cuda.synchronize()
    t0 = time.perf_counter()
    for iteration in range(args.max_iter + 1):
        if (not batch_iterator) or (iteration % epoch_size == 0):
            # create batch iterator
            batch_iterator = iter(train_data_loader)

        # load train data
        images, targets, img_indexs = next(batch_iterator)
        if args.cuda:
            images = images.cuda()
            # targets need no gradients; volatile=True was the pre-0.4 idiom
            targets = [anno.cuda() for anno in targets]

        # forward
        out = net(images)
        # backprop
        optimizer.zero_grad()
        loss_l, loss_c = criterion(out, targets)
        loss = loss_l + loss_c
        loss.backward()
        optimizer.step()
        scheduler.step()

        # .item() replaces the PyTorch <= 0.3 .data[0] idiom
        loc_loss = loss_l.item()
        conf_loss = loss_c.item()
        loc_losses.update(loc_loss)
        cls_losses.update(conf_loss)
        losses.update((loc_loss + conf_loss) / 2.0)

        if iteration % args.print_step == 0 and iteration > 0:
            if args.visdom:
                losses_list = [loc_losses.val, cls_losses.val, losses.val,
                               loc_losses.avg, cls_losses.avg, losses.avg]
                viz.line(X=torch.ones((1, 6)).cpu() * iteration,
                         Y=torch.from_numpy(np.asarray(losses_list)).unsqueeze(0).cpu(),
                         win=lot, update='append')

            torch.cuda.synchronize()
            t1 = time.perf_counter()
            batch_time.update(t1 - t0)

            print_line = 'Iteration {:06d}/{:06d} loc-loss {:.3f}({:.3f}) ' \
                         'cls-loss {:.3f}({:.3f}) average-loss {:.3f}({:.3f}) ' \
                         'Timer {:0.3f}({:0.3f})'.format(
                             iteration, args.max_iter,
                             loc_losses.val, loc_losses.avg,
                             cls_losses.val, cls_losses.avg,
                             losses.val, losses.avg,
                             batch_time.val, batch_time.avg)

            torch.cuda.synchronize()
            t0 = time.perf_counter()
            log_file.write(print_line + '\n')
            print(print_line)

            # if args.visdom and args.send_images_to_visdom:
            #     random_batch_index = np.random.randint(images.size(0))
            #     viz.image(images.data[random_batch_index].cpu().numpy())

            itr_count += 1
            if itr_count % args.loss_reset_step == 0 and itr_count > 0:
                loc_losses.reset()
                cls_losses.reset()
                losses.reset()
                batch_time.reset()
                print('Reset accumulators of ', args.exp_name, ' at',
                      itr_count * args.print_step)
                itr_count = 0

        if (iteration % args.eval_step == 0 or iteration == 5000) and iteration > 0:
            torch.cuda.synchronize()
            tvs = time.perf_counter()
            print('Saving state, iter:', iteration)
            torch.save(net.state_dict(),
                       args.save_root + 'ssd300_ucf24_' + repr(iteration) + '.pth')

            net.eval()  # switch net to evaluation mode
            mAP, ap_all, ap_strs = validate(args, net, val_data_loader,
                                            val_dataset, iteration,
                                            iou_thresh=args.iou_thresh)

            for ap_str in ap_strs:
                print(ap_str)
                log_file.write(ap_str + '\n')
            ptr_str = '\nMEANAP:::=>' + str(mAP) + '\n'
            print(ptr_str)
            log_file.write(ptr_str)

            if args.visdom:
                aps = [mAP]
                for ap in ap_all:
                    aps.append(ap)
                viz.line(X=torch.ones((1, args.num_classes)).cpu() * iteration,
                         Y=torch.from_numpy(np.asarray(aps)).unsqueeze(0).cpu(),
                         win=val_lot, update='append')

            net.train()  # switch net back to training mode
            torch.cuda.synchronize()
            t0 = time.perf_counter()
            prt_str = '\nValidation TIME::: {:0.3f}\n\n'.format(t0 - tvs)
            print(prt_str)
            # the original wrote ptr_str here, repeating the mAP line instead
            # of the validation time
            log_file.write(prt_str)

    log_file.close()
def train(): """ Train the model using the parameters defined in the config file """ print('Initialising ...') cfg = TrainConfig() checkpoint_folder = 'checkpoints/{}/'.format(cfg.experiment_name) if not os.path.exists(checkpoint_folder): os.makedirs(checkpoint_folder) tb_folder = 'tb/{}/'.format(cfg.experiment_name) if not os.path.exists(tb_folder): os.makedirs(tb_folder) writer = SummaryWriter(logdir=tb_folder, flush_secs=30) model = ParrotModel().cuda().train() optimiser = AdamW(model.parameters(), lr=cfg.initial_lr, weight_decay=cfg.weight_decay) train_dataset = ParrotDataset(cfg.train_labels, cfg.mp3_folder) train_loader = DataLoader(train_dataset, batch_size=cfg.batch_size, num_workers=cfg.workers, collate_fn=parrot_collate_function, pin_memory=True) val_dataset = ParrotDataset(cfg.val_labels, cfg.mp3_folder) val_loader = DataLoader(val_dataset, batch_size=cfg.batch_size, num_workers=cfg.workers, collate_fn=parrot_collate_function, shuffle=False, pin_memory=True) epochs = cfg.epochs init_loss, step = 0., 0 avg_loss = AverageMeter() print('Starting training') for epoch in range(epochs): loader_length = len(train_loader) epoch_start = time.time() for batch_idx, batch in enumerate(train_loader): optimiser.zero_grad() # VRAM control by skipping long examples if batch['spectrograms'].shape[-1] > cfg.max_time: continue # inference target = batch['targets'].cuda() model_input = batch['spectrograms'].cuda() model_output = model(model_input) # loss input_lengths = batch['input_lengths'].cuda() target_lengths = batch['target_lengths'].cuda() loss = ctc_loss(model_output, target, input_lengths, target_lengths) loss.backward() if epoch == 0 and batch_idx == 0: init_loss = loss # logging elapsed = time.time() - epoch_start progress = batch_idx / loader_length est = datetime.timedelta( seconds=int(elapsed / progress)) if progress > 0.001 else '-' avg_loss.update(loss) suffix = '\tloss {:.4f}/{:.4f}\tETA [{}/{}]'.format( avg_loss.avg, init_loss, datetime.timedelta(seconds=int(elapsed)), est) printProgressBar(batch_idx, loader_length, suffix=suffix, prefix='Epoch [{}/{}]\tStep [{}/{}]'.format( epoch, epochs, batch_idx, loader_length)) writer.add_scalar('Steps/train_loss', loss, step) # saving the model if step % cfg.checkpoint_every == 0: test_name = '{}/test_epoch{}.mp3'.format( checkpoint_folder, epoch) test_mp3_file(cfg.test_mp3, model, test_name) checkpoint_name = '{}/epoch_{}.pth'.format( checkpoint_folder, epoch) torch.save( { 'model': model.state_dict(), 'epoch': epoch, 'batch_idx': loader_length, 'step': step, 'optimiser': optimiser.state_dict() }, checkpoint_name) # validating if step % cfg.val_every == 0: val(model, val_loader, writer, step) model = model.train() step += 1 optimiser.step() # end of epoch print('') writer.add_scalar('Epochs/train_loss', avg_loss.avg, epoch) avg_loss.reset() test_name = '{}/test_epoch{}.mp3'.format(checkpoint_folder, epoch) test_mp3_file(cfg.test_mp3, model, test_name) checkpoint_name = '{}/epoch_{}.pth'.format(checkpoint_folder, epoch) torch.save( { 'model': model.state_dict(), 'epoch': epoch, 'batch_idx': loader_length, 'step': step, 'optimiser': optimiser.state_dict() }, checkpoint_name) # finished training writer.close() print('Training finished :)')
def train():
    set_seed(seed=10)
    os.makedirs(args.save_root, exist_ok=True)

    # create model, optimizer and criterion
    model = SSD300(n_classes=len(label_map), device=device)
    biases = []
    not_biases = []
    for name, param in model.named_parameters():
        if param.requires_grad:
            if name.endswith('.bias'):
                biases.append(param)
            else:
                not_biases.append(param)
    model = model.to(device)
    # biases get twice the base learning rate, as in the original SSD recipe
    optimizer = torch.optim.SGD(params=[{'params': biases, 'lr': 2 * args.lr},
                                        {'params': not_biases}],
                                lr=args.lr, momentum=args.momentum,
                                weight_decay=args.weight_decay)

    if args.resume is None:
        start_epoch = 0
    else:
        checkpoint = torch.load(args.resume, map_location=device)
        start_epoch = checkpoint['epoch'] + 1
        model.load_state_dict(checkpoint['model'])
        optimizer.load_state_dict(checkpoint['optimizer'])
        print(f'Training will start at epoch {start_epoch}.')

    criterion = MultiBoxLoss(priors_cxcy=model.priors_cxcy, device=device,
                             alpha=args.alpha)
    criterion = criterion.to(device)
    '''
    scheduler = StepLR(optimizer=optimizer, step_size=20, gamma=0.5,
                       last_epoch=start_epoch - 1, verbose=True)
    '''

    # load data
    transform = Transform(size=(300, 300), train=True)
    train_dataset = VOCDataset(root=args.data_root, image_set=args.image_set,
                               transform=transform, keep_difficult=True)
    train_loader = DataLoader(dataset=train_dataset, collate_fn=collate_fn,
                              batch_size=args.batch_size,
                              num_workers=args.num_workers,
                              shuffle=True, pin_memory=True)

    losses = AverageMeter()
    for epoch in range(start_epoch, args.num_epochs):
        # decay learning rate at particular epochs
        if epoch in [120, 140, 160]:
            adjust_learning_rate(optimizer, 0.1)

        # train model
        model.train()
        losses.reset()
        bar = tqdm(train_loader, desc='Train the model')
        for i, (images, bboxes, labels, _) in enumerate(bar):
            images = images.to(device)
            bboxes = [b.to(device) for b in bboxes]
            labels = [l.to(device) for l in labels]

            # (N, 8732, 4), (N, 8732, num_classes)
            predicted_bboxes, predicted_scores = model(images)
            loss = criterion(predicted_bboxes, predicted_scores, bboxes, labels)

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            losses.update(loss.item(), images.size(0))
            if i % args.print_freq == args.print_freq - 1:
                bar.write(f'Average Loss: {losses.avg:.4f}')
        bar.write(f'Epoch: [{epoch + 1}|{args.num_epochs}] '
                  f'Average Loss: {losses.avg:.4f}')

        # adjust learning rate
        # scheduler.step()

        # save model
        state_dict = {
            'epoch': epoch,
            'model': model.state_dict(),
            'optimizer': optimizer.state_dict()
        }
        save_path = os.path.join(args.save_root, 'ssd300.pth')
        torch.save(state_dict, save_path)
        if epoch % args.save_freq == args.save_freq - 1:
            shutil.copyfile(save_path,
                            os.path.join(args.save_root,
                                         f'ssd300_epochs_{epoch + 1}.pth'))
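# adjust_learning_rate() is called above but not defined in this snippet; a
# minimal sketch of the conventional helper (assumed, not from the source):
def adjust_learning_rate(optimizer, scale):
    """Multiply the learning rate of every param group by `scale`."""
    for param_group in optimizer.param_groups:
        param_group['lr'] *= scale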
class Trainer(object):
    # the most basic model
    def __init__(self, config, data_loader=None):
        self.config = config
        self.data_loader = data_loader  # needed for VAE
        self.lr = config.lr
        self.beta1 = config.beta1
        self.beta2 = config.beta2
        self.optimizer = config.optimizer
        self.batch_size = config.batch_size
        self.diffLoss = L1Loss_mask()  # custom module
        self.valmin_iter = 0
        self.model_dir = 'logs/' + str(config.expnum)
        self.savename_G = ''
        self.decoder = GreedyDecoder(data_loader.labels)

        self.kt = 0  # used for Proportional Control Theory in BEGAN, initialized as 0
        self.lb = 0.001
        self.conv_measure = 0  # convergence measure

        self.dce_tr = AverageMeter()
        self.dce_val = AverageMeter()
        self.wer_tr = AverageMeter()
        self.cer_tr = AverageMeter()
        self.wer_val = AverageMeter()
        self.cer_val = AverageMeter()

        self.build_model()
        self.G.loss_stop = 100000
        # self.get_weight_statistic()

        if self.config.gpu >= 0:
            self.G.cuda()
            self.ASR.cuda()

        if len(self.config.load_path) > 0:
            self.load_model()

        if config.mode == 'train':
            self.logFile = open(self.model_dir + '/log.txt', 'w')

    def zero_grad_all(self):
        self.G.zero_grad()

    def build_model(self):
        print('initialize enhancement model')
        self.G = stackedBRNN(I=self.config.nFeat, H=self.config.rnn_size,
                             L=self.config.rnn_layers,
                             rnn_type=supported_rnns[self.config.rnn_type])
        print('load pre-trained ASR model')
        package_ASR = torch.load(self.config.ASR_path,
                                 map_location=lambda storage, loc: storage)
        self.ASR = DeepSpeech.load_model_package(package_ASR)
        # Weight initialization is done inside the module

    def load_model(self):
        # note: the original referenced self.load_path, which is never set;
        # self.config.load_path is what __init__ checks before calling this
        print("[*] Load models from {}...".format(self.config.load_path))
        postfix = '_valmin'
        paths = glob(os.path.join(self.config.load_path,
                                  'G{}*.pth'.format(postfix)))
        paths.sort()

        if len(paths) == 0:
            print("[!] No checkpoint found in {}...".format(self.config.load_path))
            assert 0, 'checkpoint not available'

        idxes = [int(os.path.basename(path.split('.')[0].split('_')[-1]))
                 for path in paths]
        if self.config.start_iter < 0:
            self.config.start_iter = max(idxes)
        if self.config.start_iter < 0:  # if still < 0, then raise error
            raise Exception("start iter is still less than 0 --> probably try "
                            "to load initial random model")

        if self.config.gpu < 0:  # CPU
            map_location = lambda storage, loc: storage
        else:  # GPU
            map_location = None

        # Ver2
        print('Load models from ' + self.config.load_path + ', ITERATION = '
              + str(self.config.start_iter))
        self.G.load_state_dict(
            torch.load('{}/G{}_{}.pth'.format(self.config.load_path, postfix,
                                              self.config.start_iter),
                       map_location=map_location))
        print("[*] Model loaded")

    def train(self):
        # Setting
        optimizer_g = torch.optim.Adam(self.G.parameters(), lr=self.config.lr,
                                       betas=(self.beta1, self.beta2),
                                       amsgrad=True)

        for iter in trange(self.config.start_iter, self.config.max_iter):
            # Train
            data_list = self.data_loader.next(cl_ny='ny', type='train')
            inputs = _get_variable_nograd(data_list[0])
            cleans = _get_variable_nograd(data_list[1])
            mask = _get_variable_nograd(data_list[2])

            # forward
            outputs = self.G(inputs)
            dce, nElement = self.diffLoss(outputs, cleans, mask)  # already normalized inside the function

            # backward
            self.zero_grad_all()
            dce.backward()
            optimizer_g.step()

            # log
            if (iter + 1) % self.config.log_iter == 0:
                str_loss = "[{}/{}] (train) DCE: {:.7f}".format(
                    iter, self.config.max_iter, dce.item())  # .item() replaces dce.data[0]
                print(str_loss)
                self.logFile.write(str_loss + '\n')
                self.logFile.flush()

            if (iter + 1) % self.config.save_iter == 0:
                self.G.eval()

                # Measure performance on the training subset
                self.dce_tr.reset()
                self.wer_tr.reset()
                self.cer_tr.reset()
                for _ in trange(0, len(self.data_loader.trsub_dl)):
                    data_list = self.data_loader.next(cl_ny='ny', type='trsub')
                    inputs = _get_variable_volatile(data_list[0])
                    cleans = _get_variable_volatile(data_list[1])
                    mask = _get_variable_volatile(data_list[2])
                    targets, input_percentages, target_sizes = \
                        data_list[3], data_list[4], data_list[5]

                    outputs = self.G(inputs)
                    dce, nElement = self.diffLoss(outputs, cleans, mask)
                    self.dce_tr.update(dce.item(), nElement)

                    # Greedy decoding
                    wer, cer, nWord, nChar = self.greedy_decoding(
                        inputs, targets, input_percentages, target_sizes)
                    self.wer_tr.update(wer, nWord)
                    self.cer_tr.update(cer, nChar)

                str_loss = "[{}/{}] (training subset) DCE: {:.7f}".format(
                    iter, self.config.max_iter, self.dce_tr.avg)
                print(str_loss)
                self.logFile.write(str_loss + '\n')
                str_loss = "[{}/{}] (training subset) WER: {:.7f}, CER: {:.7f}".format(
                    iter, self.config.max_iter,
                    self.wer_tr.avg * 100, self.cer_tr.avg * 100)
                print(str_loss)
                self.logFile.write(str_loss + '\n')

                # Measure performance on validation data
                self.dce_val.reset()
                self.wer_val.reset()
                self.cer_val.reset()
                for _ in trange(0, len(self.data_loader.val_dl)):
                    data_list = self.data_loader.next(cl_ny='ny', type='val')
                    inputs = _get_variable_volatile(data_list[0])
                    cleans = _get_variable_volatile(data_list[1])
                    mask = _get_variable_volatile(data_list[2])
                    targets, input_percentages, target_sizes = \
                        data_list[3], data_list[4], data_list[5]

                    outputs = self.G(inputs)
                    dce, nElement = self.diffLoss(outputs, cleans, mask)
                    self.dce_val.update(dce.item(), nElement)

                    # Greedy decoding
                    wer, cer, nWord, nChar = self.greedy_decoding(
                        inputs, targets, input_percentages, target_sizes)
                    self.wer_val.update(wer, nWord)
                    self.cer_val.update(cer, nChar)

                str_loss = "[{}/{}] (validation) DCE: {:.7f}".format(
                    iter, self.config.max_iter, self.dce_val.avg)
                print(str_loss)
                self.logFile.write(str_loss + '\n')
                str_loss = "[{}/{}] (validation) WER: {:.7f}, CER: {:.7f}".format(
                    iter, self.config.max_iter,
                    self.wer_val.avg * 100, self.cer_val.avg * 100)
                print(str_loss)
                self.logFile.write(str_loss + '\n')

                self.G.train()  # end of validation
                self.logFile.flush()

                # Save model
                if len(self.savename_G) > 0:  # do not remove here
                    if os.path.exists(self.savename_G):
                        os.remove(self.savename_G)  # remove previous model
                self.savename_G = '{}/G_{}.pth'.format(self.model_dir, iter)
                torch.save(self.G.state_dict(), self.savename_G)

                if self.G.loss_stop > self.wer_val.avg:
                    self.G.loss_stop = self.wer_val.avg
                    savename_G_valmin_prev = '{}/G_valmin_{}.pth'.format(
                        self.model_dir, self.valmin_iter)
                    if os.path.exists(savename_G_valmin_prev):
                        os.remove(savename_G_valmin_prev)  # remove previous model
                    print('save model for this checkpoint')
                    savename_G_valmin = '{}/G_valmin_{}.pth'.format(self.model_dir, iter)
                    copyfile(self.savename_G, savename_G_valmin)
                    self.valmin_iter = iter

    def greedy_decoding(self, inputs, targets, input_percentages, target_sizes,
                        transcript_prob=0.001):
        # unflatten targets
        split_targets = []
        offset = 0
        for size in target_sizes:
            split_targets.append(targets[offset:offset + size])
            offset += size

        # step 1) Decoding to get wer & cer
        enhanced = self.G(inputs)
        prob = self.ASR(enhanced)
        prob = prob.transpose(0, 1)
        T = prob.size(0)
        sizes = input_percentages.mul_(int(T)).int()
        decoded_output, _ = self.decoder.decode(prob.data, sizes)
        target_strings = self.decoder.convert_to_strings(split_targets)

        we, ce, total_word, total_char = 0, 0, 0, 0
        for x in range(len(target_strings)):
            decoding, reference = decoded_output[x][0], target_strings[x][0]
            nChar = len(reference)
            nWord = len(reference.split())
            we_i = self.decoder.wer(decoding, reference)
            ce_i = self.decoder.cer(decoding, reference)
            we += we_i
            ce += ce_i
            total_word += nWord
            total_char += nChar
            if random.uniform(0, 1) < transcript_prob:
                print('reference = ' + reference)
                print('decoding = ' + decoding)
                print('wer = ' + str(we_i / float(nWord)) + ', cer = '
                      + str(ce_i / float(nChar)))

        wer = we / total_word
        # the original divided by total_word here, which misreports CER
        cer = ce / total_char
        return wer, cer, total_word, total_char
class Trainer():
    def __init__(self, model, criterion, optimizer, opt, writer):
        self.model = model
        self.criterion = criterion
        self.optimizer = optimizer
        self.batch_time = AverageMeter()
        self.data_time = AverageMeter()
        self.losses = AverageMeter()
        self.writer = writer

    def train(self, trainloader, epoch, opt):
        self.data_time.reset()
        self.batch_time.reset()
        self.model.train()
        self.losses.reset()
        end = time.time()
        for i, data in enumerate(trainloader, 0):
            self.optimizer.zero_grad()
            xh, xi, xp, shifted_targets, eyes, names, eyes2, gcorrs = data
            xh = xh.cpu()
            xi = xi.cpu()
            xp = xp.cpu()
            shifted_targets = shifted_targets.cpu().squeeze()
            self.data_time.update(time.time() - end)

            outputs = self.model(xh, xi, xp)
            total_loss = self.criterion(outputs[0],
                                        shifted_targets[:, 0, :].max(1)[1])
            for j in range(1, len(outputs)):
                total_loss += self.criterion(outputs[j],
                                             shifted_targets[:, j, :].max(1)[1])
            total_loss = total_loss / (len(outputs) * 1.0)
            total_loss.backward()
            self.optimizer.step()

            inputs_size = xh.size(0)
            self.losses.update(total_loss.item(), inputs_size)
            self.batch_time.update(time.time() - end)
            end = time.time()

            if i % opt.printfreq == 0 and opt.verbose:
                print('Epoch: [{0}][{1}/{2}]\t'
                      'Time {batch_time.avg:.3f} ({batch_time.sum:.3f})\t'
                      'Data {data_time.avg:.3f} ({data_time.sum:.3f})\t'
                      'Loss {loss.avg:.3f}\t'.format(
                          epoch, i, len(trainloader),
                          batch_time=self.batch_time,
                          data_time=self.data_time,
                          loss=self.losses))
                sys.stdout.flush()

        self.writer.add_scalar('Train Loss', self.losses.avg, epoch)
        print('Train: [{0}]\t'
              'Time {batch_time.sum:.3f}\t'
              'Data {data_time.sum:.3f}\t'
              'Loss {loss.avg:.3f}\t'.format(epoch,
                                             batch_time=self.batch_time,
                                             data_time=self.data_time,
                                             loss=self.losses))
class Validator():
    def __init__(self, model, criterion, opt, writer):
        self.model = model
        self.criterion = criterion
        self.batch_time = AverageMeter()
        self.data_time = AverageMeter()
        self.dist = AverageMeter()
        self.mindist = AverageMeter()
        self.writer = writer

    def validate(self, valloader, epoch, opt):
        self.model.eval()
        self.dist.reset()
        self.mindist.reset()
        self.data_time.reset()
        self.batch_time.reset()
        end = time.time()
        with torch.no_grad():
            for i, data in enumerate(valloader, 0):
                xh, xi, xp, targets, eyes, names, eyes2, ground_labels = data
                xh = xh.cpu()
                xi = xi.cpu()
                xp = xp.cpu()
                self.data_time.update(time.time() - end)

                outputs = self.model.predict(xh, xi, xp)
                pred_labels = outputs.max(1)[1]

                inputs_size = xh.size(0)
                distval = utils.euclid_dist(pred_labels.data.cpu(),
                                            ground_labels, inputs_size)
                # mindistval = utils.euclid_mindist(pred_labels.data.cpu(), ground_labels, inputs_size)
                self.dist.update(distval, inputs_size)
                # self.mindist.update(mindistval, inputs_size)
                self.batch_time.update(time.time() - end)
                end = time.time()

                if i % opt.printfreq == 0 and opt.verbose:
                    print('Epoch: [{0}][{1}/{2}]\t'
                          'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t'
                          'Data {data_time.val:.3f} ({data_time.avg:.3f})\t'
                          'Dist {dist.avg:.3f}\t'.format(
                              epoch, i, len(valloader),
                              batch_time=self.batch_time,
                              data_time=self.data_time,
                              dist=self.dist))
                    sys.stdout.flush()

        self.writer.add_scalar('Val Dist', self.dist.avg, epoch)
        # self.writer.add_scalar('Val Min Dist', self.mindist.avg, epoch)
        print('Val: [{0}]\t'
              'Time {batch_time.sum:.3f}\t'
              'Data {data_time.sum:.3f}\t'
              'Dist {dist.avg:.3f}\t'.format(epoch,
                                             batch_time=self.batch_time,
                                             data_time=self.data_time,
                                             dist=self.dist))
        return self.dist.avg
class Trainer(object):
    # the most basic model
    def __init__(self, config, data_loader=None):
        if config.w_minWvar > 0:
            config.minimize_W_var = True
            self.varLoss = var_mask()
        self.config = config
        self.data_loader = data_loader  # needed for VAE
        self.lr = config.lr
        self.beta1 = config.beta1
        self.beta2 = config.beta2
        self.optimizer = config.optimizer
        self.batch_size = config.batch_size
        self.diffLoss = L1Loss_mask()  # custom module

        log_domain = False
        if self.config.linear_to_mel:
            log_domain = True
        self.get_SNRout = get_SNRout(log_domain=log_domain)

        self.valmin_iter = 0
        self.model_dir = 'models/' + str(config.expnum)
        self.log_dir = 'logs_only/' + str(config.expnum)
        self.savename_G = ''
        self.decoder = GreedyDecoder(data_loader.labels)

        self.kt = 0  # used for Proportional Control Theory in BEGAN, initialized as 0
        self.lb = 0.001
        self.conv_measure = 0  # convergence measure

        self.dce_tr = AverageMeter()
        self.dce_val = AverageMeter()
        self.snrout_tr = AverageMeter()
        self.snrout_val = AverageMeter()
        self.snrimpv_tr = AverageMeter()
        self.snrimpv_val = AverageMeter()

        if config.linear_to_mel:
            self.mel_basis = librosa.filters.mel(self.config.fs,
                                                 self.config.nFFT,
                                                 self.config.nMel)
            self.melF_to_linearFs = get_linearF_from_melF(self.mel_basis)
            self.STFT_to_LMFB = STFT_to_LMFB(self.mel_basis, window_change=False)
            self.mag2mfb = linearmag2mel(self.mel_basis)
            # mel_basis_20ms will be used only for 20ms-window spectrograms
            mel_basis_20ms = librosa.filters.mel(self.config.fs, 320,
                                                 self.config.nMel)
            self.STFT_to_LMFB_20ms = STFT_to_LMFB(mel_basis_20ms,
                                                  win_size=self.config.nFFT)

        self.F = int(self.config.nFFT / 2 + 1)
        self.build_model()
        self.G.loss_stop = 100000
        # self.get_weight_statistic()

        if self.config.gpu >= 0:
            self.G.cuda()

        if len(self.config.load_path) > 0:
            self.load_model()

        if config.mode == 'train':
            self.logFile = open(self.log_dir + '/log.txt', 'w')

    def zero_grad_all(self):
        self.G.zero_grad()

    def build_model(self):
        self.G = LineartoMel_real(F=self.F,
                                  melF_to_linearFs=self.melF_to_linearFs,
                                  nCH=self.config.nCH,
                                  w=self.config.convW,
                                  H=self.config.nMap_per_F,
                                  L=self.config.L_CNN,
                                  non_linear=self.config.non_linear,
                                  BN=self.config.complex_BN)  # the model currently in use
        G_name = 'LineartoMel_real'
        print('initialized enhancement model as ' + G_name)
        nParam = count_parameters(self.G)
        print('# trainable parameters = ' + str(nParam))

    def load_model(self):
        print("[*] Load models from {}...".format(self.config.load_path))
        postfix = '_valmin'
        paths = glob(os.path.join(self.config.load_path,
                                  'G{}*.pth'.format(postfix)))
        paths.sort()

        if len(paths) == 0:
            # the original printed self.load_path here, an attribute that is never set
            print("[!] No checkpoint found in {}...".format(self.config.load_path))
            assert 0, 'checkpoint not available'

        idxes = [int(os.path.basename(path.split('.')[0].split('_')[-1]))
                 for path in paths]
        if self.config.start_iter <= 0:
            self.config.start_iter = max(idxes)
        if self.config.start_iter <= 0:  # if still <= 0, then raise error
            raise Exception("start iter is still less than 0 --> probably try "
                            "to load initial random model")

        if self.config.gpu < 0:  # CPU
            map_location = lambda storage, loc: storage
        else:  # GPU
            map_location = None

        # Ver2
        print('Load models from ' + self.config.load_path + ', ITERATION = '
              + str(self.config.start_iter))
        self.G.load_state_dict(
            torch.load('{}/G{}_{}.pth'.format(self.config.load_path, postfix,
                                              self.config.start_iter),
                       map_location=map_location))
        print("[*] Model loaded")

    def train(self):
        # Setting
        optimizer_g = torch.optim.Adam(self.G.parameters(), lr=self.config.lr,
                                       betas=(self.beta1, self.beta2),
                                       amsgrad=True)

        for iter in trange(self.config.start_iter, self.config.max_iter):
            # Train
            data_list = self.data_loader.next(cl_ny='ny', type='train')
            inputs, cleans, mask = data_list[0], data_list[1], data_list[2]  # cleans: NxFxT, mask: Nx1xT
            if len(data_list) >= 9:
                mixture_magnitude = data_list[7]
                mixture_phsdiff = data_list[8]
                inputs_augmented = torch.cat((torch.log(1 + mixture_magnitude),
                                              mixture_phsdiff), dim=2)
                mfb = self.mag2mfb(mixture_magnitude)
                cleans = self.STFT_to_LMFB(cleans)

            if self.config.linear_to_mel:
                inputs = [_get_variable(inputs_augmented), _get_variable(mfb)]
            else:
                inputs = _get_variable(inputs)
            cleans = _get_variable(cleans)
            mask = _get_variable(mask)

            # forward: input = [log(magnitude), phase difference] --> output = log-mel-filterbank
            outputs = self.G(inputs)
            dce, nElement = self.diffLoss(outputs, cleans, mask)  # already normalized inside the function

            if self.config.loss_per_freq:
                if (iter + 1) % self.config.log_iter == 0:
                    for f in range(dce.size(0)):
                        str_loss = "[{}/{}] (train) DCE_{}: {:.7f}".format(
                            iter, self.config.max_iter, f, dce[f].sum().item())
                        self.logFile.write(str_loss + '\n')
                dce = dce.sum()  # sum up all the loss
            total_loss = dce

            # backward
            self.zero_grad_all()
            total_loss.backward()
            optimizer_g.step()

            # log
            if (iter + 1) % self.config.log_iter == 0:
                str_loss = "[{}/{}] (train) DCE: {:.7f}".format(
                    iter, self.config.max_iter, dce.item())
                print(str_loss)
                self.logFile.write(str_loss + '\n')

                SNRout = self.get_SNRout(outputs, cleans, mask)
                SNRout = SNRout.sum() / cleans.size(0)
                str_loss = "[{}/{}] (train) SNRout: {:.7f}".format(
                    iter, self.config.max_iter, SNRout.item())
                print(str_loss)
                self.logFile.write(str_loss + '\n')
                self.logFile.flush()

            if (iter + 1) % self.config.save_iter == 0:
                with torch.no_grad():
                    self.G.eval()
                    self.diffLoss.eval()

                    # Measure performance on the training subset
                    self.dce_tr.reset()
                    self.snrout_tr.reset()
                    self.snrimpv_tr.reset()
                    for _ in trange(0, len(self.data_loader.trsub_dl)):
                        data_list = self.data_loader.next(cl_ny='ny', type='trsub')
                        inputs, cleans, mask = data_list[0], data_list[1], data_list[2]
                        if len(data_list) >= 6:
                            targets, input_percentages, target_sizes = \
                                data_list[3], data_list[4], data_list[5]
                        if len(data_list) >= 7:
                            SNRin_1s = _get_variable(data_list[6])
                        if len(data_list) >= 9:
                            mixture_magnitude = data_list[7]
                            mixture_phsdiff = data_list[8]
                            inputs_augmented = torch.cat(
                                (torch.log(1 + mixture_magnitude), mixture_phsdiff), dim=2)
                            mfb = self.mag2mfb(mixture_magnitude)
                            cleans = self.STFT_to_LMFB(cleans)
                        cleans, mask = _get_variable(cleans), _get_variable(mask)
                        if self.config.linear_to_mel:
                            inputs = [_get_variable(inputs_augmented), _get_variable(mfb)]
                        else:
                            inputs = _get_variable(inputs)

                        # Forward (training subset)
                        outputs = self.G(inputs)
                        dce, nElement = self.diffLoss(outputs, cleans, mask)
                        self.dce_tr.update(dce.item(), nElement)

                        SNRout = self.get_SNRout(outputs, cleans, mask)
                        SNRimprovement = SNRout - SNRin_1s
                        SNRout = SNRout.sum() / cleans.size(0)
                        SNRimprovement = SNRimprovement.sum() / cleans.size(0)
                        self.snrout_tr.update(SNRout.item(), cleans.size(0))
                        self.snrimpv_tr.update(SNRimprovement.item(), cleans.size(0))

                    str_loss = "[{}/{}] (training subset) DCE: {:.7f}".format(
                        iter, self.config.max_iter, self.dce_tr.avg)
                    print(str_loss)
                    self.logFile.write(str_loss + '\n')
                    str_loss = "[{}/{}] (training subset) SNRout: {:.7f}".format(
                        iter, self.config.max_iter, self.snrout_tr.avg)
                    print(str_loss)
                    self.logFile.write(str_loss + '\n')
                    str_loss = "[{}/{}] (training subset) SNRimprovement: {:.7f}".format(
                        iter, self.config.max_iter, self.snrimpv_tr.avg)
                    print(str_loss)
                    self.logFile.write(str_loss + '\n')

                    # Measure performance on validation data
                    # (the original reset wer_val/cer_val, meters this class never
                    # creates, and re-reset snrout_tr/snrimpv_tr; the validation
                    # meters updated below are the ones reset here)
                    self.dce_val.reset()
                    self.snrout_val.reset()
                    self.snrimpv_val.reset()
                    for _ in trange(0, len(self.data_loader.val_dl)):
                        data_list = self.data_loader.next(cl_ny='ny', type='val')
                        inputs, cleans, mask = data_list[0], data_list[1], data_list[2]
                        if len(data_list) >= 6:
                            targets, input_percentages, target_sizes = \
                                data_list[3], data_list[4], data_list[5]
                        if len(data_list) >= 7:
                            SNRin_1s = _get_variable(data_list[6])
                        if len(data_list) >= 9:
                            mixture_magnitude = data_list[7]
                            mixture_phsdiff = data_list[8]
                            mfb = self.mag2mfb(mixture_magnitude)
                            inputs_augmented = torch.cat(
                                (torch.log(1 + mixture_magnitude), mixture_phsdiff), dim=2)
                            cleans = self.STFT_to_LMFB(cleans)
                        cleans, mask = _get_variable(cleans), _get_variable(mask)
                        if self.config.linear_to_mel:
                            inputs = [_get_variable(inputs_augmented), _get_variable(mfb)]
                        else:
                            inputs = _get_variable(inputs)

                        # Forward (validation)
                        outputs = self.G(inputs)
                        dce, nElement = self.diffLoss(outputs, cleans, mask)
                        self.dce_val.update(dce.item(), nElement)

                        SNRout = self.get_SNRout(outputs, cleans, mask)
                        SNRimprovement = SNRout - SNRin_1s
                        SNRout = SNRout.sum() / cleans.size(0)
                        SNRimprovement = SNRimprovement.sum() / cleans.size(0)
                        self.snrout_val.update(SNRout.item(), cleans.size(0))
                        self.snrimpv_val.update(SNRimprovement.item(), cleans.size(0))

                    str_loss = "[{}/{}] (validation) DCE: {:.7f}".format(
                        iter, self.config.max_iter, self.dce_val.avg)
                    print(str_loss)
                    self.logFile.write(str_loss + '\n')
                    str_loss = "[{}/{}] (validation) SNRout: {:.7f}".format(
                        iter, self.config.max_iter, self.snrout_val.avg)
                    print(str_loss)
                    self.logFile.write(str_loss + '\n')
                    str_loss = "[{}/{}] (validation) SNRimprovement: {:.7f}".format(
                        iter, self.config.max_iter, self.snrimpv_val.avg)
                    print(str_loss)
                    self.logFile.write(str_loss + '\n')

                    self.G.train()  # end of validation
                    self.diffLoss.train()
                    self.logFile.flush()

                    # Save model
                    if len(self.savename_G) > 0:  # do not remove here
                        if os.path.exists(self.savename_G):
                            os.remove(self.savename_G)  # remove previous model
                    self.savename_G = '{}/G_{}.pth'.format(self.model_dir, iter)
                    torch.save(self.G.state_dict(), self.savename_G)

                    # keep the best checkpoint by validation loss; the original
                    # compared against self.wer_val.avg, a meter this class never
                    # creates (likely carried over from the ASR trainer), so the
                    # tracked validation loss is used here instead
                    if self.G.loss_stop > self.dce_val.avg:
                        self.G.loss_stop = self.dce_val.avg
                        savename_G_valmin_prev = '{}/G_valmin_{}.pth'.format(
                            self.model_dir, self.valmin_iter)
                        if os.path.exists(savename_G_valmin_prev):
                            os.remove(savename_G_valmin_prev)  # remove previous model
                        print('save model for this checkpoint')
                        savename_G_valmin = '{}/G_valmin_{}.pth'.format(self.model_dir, iter)
                        copyfile(self.savename_G, savename_G_valmin)
                        self.valmin_iter = iter