Example #1
    def train_model(self, data_loader, criterion, optimizer):
        device = 'cuda:0' if torch.cuda.is_available() else 'cpu'
        self.train()
        self.to(device)
        loss_avg = AverageMeter()
        acc_avg = AverageMeter()
        loss_avg.reset()
        acc_avg.reset()

        for label, image in tqdm(data_loader):
            image: torch.Tensor = image.to(device)
            label: torch.Tensor = label.to(device).float().unsqueeze(dim=1)

            pred: torch.Tensor = self.forward(image)

            loss: torch.Tensor = criterion(pred, label)

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            loss_avg.update(loss.item())
            pred = (pred + 0.5).long()
            num_correct = (pred == label).sum().item()
            acc_avg.update(num_correct / image.shape[0])

        return loss_avg.avg, acc_avg.avg
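Every snippet on this page assumes an AverageMeter helper defined elsewhere in its project. They all rely on the same small interface: reset(), update(val, n=1), and the val, sum, and avg attributes (Example #13 uses a variant that exposes average instead of avg). A minimal sketch of that interface, modeled on the common PyTorch example-repository implementation:

class AverageMeter:
    """Tracks the most recent value and a running, optionally weighted, average."""

    def __init__(self):
        self.reset()

    def reset(self):
        self.val = 0.0
        self.sum = 0.0
        self.count = 0
        self.avg = 0.0

    def update(self, val, n=1):
        self.val = val           # last observed value
        self.sum += val * n      # weighted running sum
        self.count += n          # number of samples seen so far
        self.avg = self.sum / self.count

With this interface, a call such as losses.update(loss.item(), batch_size) accumulates a per-sample average, and losses.avg is the quantity the logging lines in the snippets print.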
 def validate(self, epoch):
     self.model.eval()
     losses = AverageMeter()
     times = AverageMeter()
     losses.reset()
     times.reset()
     len_d = len(self.valid_loader)
     end = time.time()
     with torch.no_grad():
         for i, data in enumerate(self.valid_loader):
             input, label = data
             input = [ele.to(self.device) for ele in input]
             label = [ele.to(self.device) for ele in label]
             output = self.model(input)
             bat_val_loss = self.loss_fn(output, label)
             bat_val_loss_avg = torch.mean(bat_val_loss)
             losses.update(bat_val_loss_avg.item())
             times.update(time.time() - end)
             end = time.time()
             writer.add_scalar('valid_loss/batch_loss', bat_val_loss_avg, epoch * len_d + i)
             print('epoch %d, %d/%d, validation loss: %f, time estimated: %.2f seconds' % (epoch, i + 1, len_d, bat_val_loss_avg, times.avg * len_d), end='\r')
         print("\n")
     writer.add_scalar('valid_loss/valid_loss', losses.avg, epoch)
     if losses.avg < self.min_loss:
         self.early_stop_count = 0
         self.min_loss = losses.avg
         torch.save(self.model.state_dict(),self.output_path+"/model.epoch%d"%epoch)
         print("Saved new model")
     else:
         self.early_stop_count += 1
 def train(self, epoch):
     losses = AverageMeter()
     times = AverageMeter()
     losses.reset()
     times.reset()
     self.model.train()
     len_d = len(self.train_loader)
     end = time.time()
     for i, data in enumerate(self.train_loader):
         input, label = data
         input = [ele.to(self.device) for ele in input]
         label = [ele.to(self.device) for ele in label]
         output = self.model(input)
         bat_loss = self.loss_fn(output, label)
         bat_loss_avg = torch.mean(bat_loss)
         losses.update(bat_loss_avg.item())
         self.optimizer.zero_grad()
         bat_loss_avg.backward()
         torch.nn.utils.clip_grad_norm_(self.model.parameters(), 1)
         self.optimizer.step()
         times.update(time.time() - end)
         end = time.time()
         writer.add_scalar('train_loss/batch_loss', bat_loss_avg, epoch * len_d + i)
         print('epoch %d, %d/%d, training loss: %f, time estimated: %.2f seconds' % (epoch, i + 1, len_d, bat_loss_avg, times.avg * len_d), end='\r')
     self.scheduler.step()
     print("\n")
     writer.add_scalar('train_loss/train_loss', losses.avg, epoch)
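The validate and train methods above (and Example #20 further down) log through a module-level writer object that the snippets never define. A minimal, assumed setup using torch.utils.tensorboard is sketched below; the log directory is a placeholder, and the original projects may use tensorboardX instead (Example #25 constructs its writer with the logdir keyword).

from torch.utils.tensorboard import SummaryWriter

# Assumed module-level writer behind the add_scalar calls above; 'runs/exp' is a placeholder path.
writer = SummaryWriter(log_dir='runs/exp')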
Example #4
 def train(self, epoch):
     losses = AverageMeter()
     times = AverageMeter()
     losses.reset()
     times.reset()
     self.model.train()
     len_d = len(self.train_loader)
     init_time = time.time()
     end = init_time
     for i, data in enumerate(self.train_loader):
         input, label = data
         output = self.model(input)
         loss = self.loss_fn(output, label)
         loss_avg = torch.mean(loss)
         losses.update(loss_avg.item())
         self.optimizer.zero_grad()
         loss_avg.backward()
         torch.nn.utils.clip_grad_norm_(self.model.parameters(), 1)
         self.optimizer.step()
         times.update(time.time() - end)
         end = time.time()
         print(
             'epoch %d, %d/%d, training loss: %f, time estimated: %.2f/%.2f seconds'
             % (epoch, i + 1, len_d, losses.avg, end - init_time,
                times.avg * len_d),
             end='\r')
     print("\n")
Example #5
 def test(self):
     times = AverageMeter()
     times.reset()
     len_d = len(self.test_loader)
     end = time.time()
     with torch.no_grad():
         for i, data in enumerate(self.test_loader):
             input, infdat = data
             input = [ele.to(self.device) for ele in input]
             output = self.model(input)
             audio_out = output[0].squeeze().cpu().detach().numpy()
             out_path = self.output_path + '/out'
             if not os.path.exists(out_path):
                 os.makedirs(out_path)
             fn = out_path + '/' + infdat[0][0]
             audio_out = audio_out.squeeze().T
             sf.write(fn,
                      audio_out,
                      self.feature_options.sampling_rate,
                      subtype='PCM_16')
             times.update(time.time() - end)
             end = time.time()
             print('%d/%d, time estimated: %.2f seconds' %
                   (i + 1, len_d, times.avg * len_d),
                   end='\r')
     print("\n")
Example #6
    def _train_one_epoch(self, epoch):
        self.model.train()
        loss_meter = AverageMeter()
        time_meter = TimeMeter()
        for bid, (video, video_mask, words, word_mask,
                  label, scores, scores_mask, id2pos, node_mask, adj_mat) in enumerate(self.train_loader, 1):
            self.optimizer.zero_grad()

            model_input = {
                'frames': video.cuda(),
                'frame_mask': video_mask.cuda(), 'words': words.cuda(), 'word_mask': word_mask.cuda(),
                'label': scores.cuda(), 'label_mask': scores_mask.cuda(), 'gt': label.cuda(),
                'node_pos': id2pos.cuda(), 'node_mask': node_mask.cuda(), 'adj_mat': adj_mat.cuda()
            }

            predict_boxes, loss, _, _, _ = self.model(**model_input)
            loss = torch.mean(loss)
            self.optimizer.backward(loss)

            self.optimizer.step()
            self.num_updates += 1
            curr_lr = self.lr_scheduler.step_update(self.num_updates)

            loss_meter.update(loss.item())
            time_meter.update()

            if bid % self.args.display_n_batches == 0:
                logging.info('Epoch %d, Batch %d, loss = %.4f, lr = %.5f, %.3f seconds/batch' % (
                    epoch, bid, loss_meter.avg, curr_lr, 1.0 / time_meter.avg
                ))
                loss_meter.reset()
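Example #6 also relies on a TimeMeter. Since 1.0 / time_meter.avg is reported as seconds per batch, the meter presumably tracks the rate of update() calls per second (the convention used by fairseq's meter of the same name). A minimal sketch under that assumption:

import time

class TimeMeter:
    """Measures the average number of update() calls per second."""

    def __init__(self):
        self.reset()

    def reset(self):
        self.start = time.time()
        self.n = 0

    def update(self, val=1):
        self.n += val

    @property
    def avg(self):
        elapsed = time.time() - self.start
        return self.n / elapsed if elapsed > 0 else 0.0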
Example #7
def val(args, model=None, current_epoch=0):
    top1 = AverageMeter()
    top5 = AverageMeter()
    top1.reset()
    top5.reset()

    if model is None:
        model = get_model(args)
    model.eval()
    _, val_loader = data_loader(args, test_path=True)

    save_atten = SAVE_ATTEN(save_dir=args.save_atten_dir)

    global_counter = 0
    prob = None
    gt = None
    for idx, dat in tqdm(enumerate(val_loader)):
        img_path, img, label_in = dat
        global_counter += 1
        if args.tencrop == 'True':
            bs, ncrops, c, h, w = img.size()
            img = img.view(-1, c, h, w)
            label_input = label_in.repeat(10, 1)
            label = label_input.view(-1)
        else:
            label = label_in

        img, label = img.cuda(), label.cuda()
        img_var, label_var = Variable(img), Variable(label)

        logits = model(img_var, label_var)

        logits0 = logits[0]
        logits0 = F.softmax(logits0, dim=1)
        if args.tencrop == 'True':
            logits0 = logits0.view(bs, ncrops, -1).mean(1)

        # Calculate classification results
        prec1_1, prec5_1 = Metrics.accuracy(logits0.cpu().data,
                                            label_in.long(),
                                            topk=(1, 5))
        # prec3_1, prec5_1 = Metrics.accuracy(logits[1].data, label.long(), topk=(1,5))
        top1.update(prec1_1[0], img.size()[0])
        top5.update(prec5_1[0], img.size()[0])

        # save_atten.save_heatmap_segmentation(img_path, np_last_featmaps, label.cpu().numpy(),
        #                                      save_dir='./save_bins/heatmaps', size=(0,0), maskedimg=True)
        #

        np_last_featmaps = logits[2].cpu().data.numpy()
        np_last_featmaps = logits[-1].cpu().data.numpy()
        np_scores, pred_labels = torch.topk(logits0, k=args.num_classes, dim=1)
        pred_np_labels = pred_labels.cpu().data.numpy()
        save_atten.save_top_5_pred_labels(pred_np_labels[:, :5], img_path,
                                          global_counter)

        # pred_np_labels[:,0] = label.cpu().numpy() #replace the first label with gt label
        # save_atten.save_top_5_atten_maps(np_last_featmaps, pred_np_labels, img_path)

    print('Top1:', top1.avg, 'Top5:', top5.avg)
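Example #7 computes its precision numbers through Metrics.accuracy(..., topk=(1, 5)), which is not shown. A sketch of the standard top-k accuracy routine it appears to follow; the helper name here is illustrative, and it returns one single-element tensor per requested k, which matches the prec1_1[0] / prec5_1[0] indexing above:

import torch

def topk_accuracy(output, target, topk=(1,)):
    """Precision@k for each k in topk, returned as percentages of the batch."""
    maxk = max(topk)
    batch_size = target.size(0)
    _, pred = output.topk(maxk, dim=1, largest=True, sorted=True)  # (batch, maxk) indices
    pred = pred.t()                                                # (maxk, batch)
    correct = pred.eq(target.view(1, -1).expand_as(pred))
    res = []
    for k in topk:
        correct_k = correct[:k].reshape(-1).float().sum(0, keepdim=True)
        res.append(correct_k.mul_(100.0 / batch_size))
    return res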
Example #8
 def validate(self, epoch):
     self.model.eval()
     losses = AverageMeter()
     times = AverageMeter()
     losses.reset()
     times.reset()
     len_d = len(self.valid_loader)
     end = time.time()
     for i, data in enumerate(self.valid_loader):
         begin = time.time()
         input, label = data
         if torch.sum(label[0]) < 1:
             continue
         output = self.model(input)
         loss = self.loss_fn(output, label)
         loss_avg = torch.mean(loss)
         losses.update(loss_avg.item())
         times.update(time.time() - end)
         end = time.time()
         print(
             'epoch %d, %d/%d, validation loss: %f, time estimated: %.2f seconds'
             % (epoch, i + 1, len_d, losses.avg, times.avg * len_d),
             end='\r')
     print("\n")
     if losses.avg < self.min_loss:
         self.early_stop_count = 0
         self.min_loss = losses.avg
         torch.save(self.model, self.output_path + "/model.epoch%d" % epoch)
         print("Saved new model")
     else:
         self.early_stop_count += 1
def train(epoch):
    net.train()
    total_loss = AverageMeter()
    epoch_loss_stats = AverageMeter()
    start_time = time.time()

    bar = tqdm(enumerate(train_loader))
    for batch_idx, (inputs, labels) in bar:
        inputs = Variable(inputs)
        labels = Variable(labels)
        if args.cuda:
            inputs = inputs.cuda()
            labels = labels.cuda()
        optimizer.zero_grad()
        outputs = net(inputs)
        outputs = F.log_softmax(outputs, dim=1)  # dim must be given explicitly in current PyTorch
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()

        total_loss.update(loss.data[0], inputs.size(0))
        epoch_loss_stats.update(loss.data[0], inputs.size(0))

        if args.visdom is not None:
            cur_iter = batch_idx + (epoch - 1) * len(train_loader)
            vis.plot_line('train_plt',
                          X=torch.ones((1, )).cpu() * cur_iter,
                          Y=loss.data.cpu(),
                          update='append')

        if batch_idx % args.backup_iters == 0:
            filename = 'texture_{0}_snapshot.pth'.format(args.split)
            filename = osp.join(args.save_folder, filename)
            state_dict = net.state_dict()
            torch.save(state_dict, filename)

            optim_filename = 'texture_{0}_optim.pth'.format(args.split)
            optim_filename = osp.join(args.save_folder, optim_filename)
            state_dict = optimizer.state_dict()
            torch.save(state_dict, optim_filename)

        if batch_idx % args.log_interval == 0:
            elapsed_time = time.time() - start_time
            bar.set_description('[{:5d}] ({:5d}/{:5d}) | ms/batch {:.6f} |'
                                ' loss {:.6f} | lr {:.7f}'.format(
                                    epoch, batch_idx, len(train_loader),
                                    elapsed_time * 1000, total_loss.avg,
                                    optimizer.param_groups[0]['lr']))
            total_loss.reset()

        start_time = time.time()

    epoch_total_loss = epoch_loss_stats.avg

    if args.visdom is not None:
        vis.plot_line('train_epoch_plt',
                      X=torch.ones((1, )).cpu() * epoch,
                      Y=torch.ones((1, )).cpu() * epoch_total_loss,
                      update='append')
    return epoch_total_loss
Example #10
def train(epoch, relative_age=True):
    net.train()
    total_loss = AverageMeter()
    epoch_loss_stats = AverageMeter()
    time_stats = AverageMeter()
    loss = 0
    optimizer.zero_grad()
    for (batch_idx, (imgs, bone_ages, genders, chronological_ages,
                     _)) in enumerate(train_loader):
        imgs = imgs.to(device)
        bone_ages = bone_ages.to(device)
        genders = genders.to(device)
        chronological_ages = chronological_ages.to(device)
        if relative_age:
            relative_ages = chronological_ages.squeeze(1) - bone_ages

        start_time = time.time()
        outputs = net(imgs, genders, chronological_ages)
        if relative_age:
            loss = criterion(outputs.squeeze(), relative_ages)
        else:
            loss = criterion(outputs.squeeze(), bone_ages)
        loss.backward()
        optimizer.step()

        loss = metric_average(loss.item(), 'loss')

        time_stats.update(time.time() - start_time, 1)
        total_loss.update(loss, 1)
        epoch_loss_stats.update(loss, 1)
        optimizer.zero_grad()

        if (batch_idx % args.log_interval == 0) and args.rank == 0:
            elapsed_time = time_stats.avg
            print(' [{:5d}] ({:5d}/{:5d}) | ms/batch {:.4f} |'
                  ' loss {:.6f} | avg loss {:.6f} | lr {:.7f}'.format(
                      epoch, batch_idx, len(train_loader), elapsed_time * 1000,
                      total_loss.avg, epoch_loss_stats.avg,
                      optimizer.param_groups[0]['lr']))
            total_loss.reset()

    epoch_total_loss = epoch_loss_stats.avg
    args.resume_iter = 0

    if args.rank == 0:
        filename = 'boneage_bonet_snapshot.pth'
        filename = osp.join(args.save_folder, filename)
        torch.save(net.state_dict(), filename)

        optim_filename = 'boneage_bonet_optim.pth'
        optim_filename = osp.join(args.save_folder, optim_filename)
        torch.save(optimizer.state_dict(), optim_filename)

    return epoch_total_loss
Example #11
def train(epoch):
    net.train()
    total_loss = AverageMeter()
    # total_loss = 0
    epoch_loss_stats = AverageMeter()
    # epoch_total_loss = 0
    start_time = time.time()

    for i_batch, sample_batched in enumerate(train_loader):
        im = Variable(sample_batched[0])
        label = Variable(sample_batched[1])
        if args.cuda:
            im = im.cuda()
            label = label.cuda()

        optimizer.zero_grad()
        out_masks = net(im, label)
        out_masks = out_masks.cuda()
        loss = criterion(out_masks, label)
        loss.backward()
        optimizer.step()
        total_loss.update(loss.data[0], im.size(0))
        epoch_loss_stats.update(loss.data[0], im.size(0))

        if i_batch % args.backup_iters == 0:
            filename = 'casenet_{0}_{1}_snapshot.pth'.format(
                args.dataset, args.split)
            filename = osp.join(args.save_folder, filename)
            state_dict = net.state_dict()
            torch.save(state_dict, filename)

            optim_filename = 'casenet_{0}_{1}_optim.pth'.format(
                args.dataset, args.split)
            optim_filename = osp.join(args.save_folder, optim_filename)
            state_dict = optimizer.state_dict()
            torch.save(state_dict, optim_filename)

        if i_batch % args.log_interval == 0:
            elapsed_time = time.time() - start_time
            # cur_loss = total_loss / args.log_interval
            print('[{:5d}] ({:5d}/{:5d}) | ms/batch {:.6f} |'
                  ' loss {:.6f} | lr {:.7f}'.format(epoch, i_batch,
                                                    len(train_loader),
                                                    elapsed_time * 1000,
                                                    total_loss.avg,
                                                    scheduler.get_lr()[0]))
            total_loss.reset()

        start_time = time.time()

    epoch_total_loss = epoch_loss_stats.avg

    return epoch_total_loss
Example #12
def test(model,
         test_loader,
         epoch,
         margin,
         threshlod,
         is_cuda=True,
         log_interval=1000):
    model.eval()
    test_loss = AverageMeter()
    accuracy = 0
    num_p = 0
    total_num = 0
    batch_num = len(test_loader)
    for batch_idx, (data_a, data_p, data_n, target) in enumerate(test_loader):
        if is_cuda:
            data_a = data_a.cuda()
            data_p = data_p.cuda()
            data_n = data_n.cuda()
            target = target.cuda()

        data_a = Variable(data_a, volatile=True)
        data_p = Variable(data_p, volatile=True)
        data_n = Variable(data_n, volatile=True)
        target = Variable(target)

        out_a = model(data_a)
        out_p = model(data_p)
        out_n = model(data_n)

        loss = F.triplet_margin_loss(out_a, out_p, out_n, margin)

        dist1 = F.pairwise_distance(out_a, out_p)
        dist2 = F.pairwise_distance(out_a, out_n)
        #print('dist1', dist1)
        #print('dist2',dist2)
        #print('threshlod', threshlod)

        num = ((dist1 < threshlod).sum() + (dist2 > threshlod).sum()).data[0]
        num_p += num
        num_p = 1.0 * num_p
        total_num += data_a.size()[0] * 2
        #print('num--num_p -- total',  num, num_p , total_num)
        test_loss.update(loss.data[0])
        if (batch_idx + 1) % log_interval == 0:
            accuracy_tmp = num_p / total_num
            print('Test- Epoch {:04d}\tbatch:{:06d}/{:06d}\tAccuracy:{:.04f}\tloss:{:06f}'\
                    .format(epoch, batch_idx+1, batch_num, accuracy_tmp, test_loss.avg))
            test_loss.reset()

    accuracy = num_p / total_num
    return accuracy
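Several of the older snippets on this page (including Example #12 above) were written against pre-0.4 PyTorch: they wrap tensors in Variable, mark inference inputs with volatile=True, move data with .cuda(async=True), and read scalars with loss.data[0]. None of those idioms work on current PyTorch and Python (async is now a reserved word). A short sketch of the modern equivalents, shown on an illustrative evaluation step; evaluate_batch and its arguments are placeholders, not part of any example above:

import torch

# Legacy idiom                      -> modern replacement
#   Variable(x, volatile=True)      -> plain tensors inside a torch.no_grad() block
#   x.cuda(async=True)              -> x.cuda(non_blocking=True) or x.to(device, non_blocking=True)
#   loss.data[0]                    -> loss.item()
def evaluate_batch(model, criterion, data, target, device='cuda'):
    model.eval()
    with torch.no_grad():                                  # replaces volatile=True
        data = data.to(device, non_blocking=True)
        target = target.to(device, non_blocking=True)
        output = model(data)
        loss = criterion(output, target)
    return loss.item()                                     # replaces loss.data[0]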
Example #13
    def run(self, epochs, lr_decay=False, mixup=False):
        train_loss_meter = AverageMeter()
        min_val_loss = 1e8
        tolerance = self.early_stopping_tolerance

        self.scheduler_setup(lr_decay)
        for epoch in range(1, epochs + 1):
            losses = train(
                self.model,
                self.device,
                self.train_loader,
                self.optimizer,
                self.criterion,
                epoch,
                mixup,
                train_loss_meter,
            )  # Returns loss per batch
            self.train_loss.extend(losses)

            val_loss, val_accuracy = test(
                self.model, self.device, self.val_loader,
                self.criterion)  # Returns loss/accuracy per epoch
            self.test_loss.append(val_loss)
            self.test_accuracy.append(val_accuracy)

            self.scheduler_step(lr_decay, epoch)

            print(
                f"Epoch {epoch} \t"
                f"train_loss: {train_loss_meter.average:.6f}"
                f"\tval_loss: {val_loss:.4f}\tval_accuracy: {val_accuracy * 100:.2f}%"
            )
            train_loss_meter.reset()

            if val_loss < min_val_loss:
                min_val_loss = val_loss
                tolerance = (
                    self.early_stopping_tolerance
                )  # Reset the tolerance because validation loss improved
            else:
                if epoch > 20:  # Early stopping doesn't start before 20 epochs
                    tolerance -= 1

            if tolerance == 0:
                # Early stopping the training process
                print(f"\nEarly stopping. Val loss did not improve for "
                      f"{self.early_stopping_tolerance} consecutive epochs")
                break
Example #14
def do_test(cfg, model, test_loader, experiment_name):
    test_acc_meter = AverageMeter()
    test_acc_meter.reset()

    device = cfg.MODEL.DEVICE

    logger = logging.getLogger('{}.test'.format(cfg.PROJECT.NAME))
    logger.info("Enter Image Classification Test")

    if device:
        if torch.cuda.device_count() > 1:
            print('Using {} GPUs for test'.format(
                torch.cuda.device_count()))
            model = nn.DataParallel(model)
        model.to(device)

    # generate result csv
    output_dir = os.path.join(cfg.MODEL.OUTPUT_PATH, experiment_name)
    result_path = os.path.join(
        output_dir,
        experiment_name  + '_' + 'test_result.csv')
    with open(result_path, 'w') as f:
        f.write("file_name,label,predictive_label")

    model.eval()
    for iteration, (img, vid, vname) in enumerate(test_loader):
        with torch.no_grad():
            img = img.to(device)
            vid = torch.tensor(vid)
            target = vid.to(device)            
            score = model(img)
            p_label = score.max(1)[1]
            acc = (score.max(1)[1] == target).float().mean()
            test_acc_meter.update(acc, 1)

            logger.info(
                "Iteration[{}/{}], Test_Acc: {:.3f}"
                .format((iteration + 1), len(test_loader), test_acc_meter.avg))

        with open(result_path, 'a+') as f:
            for i in range(len(vid)):
                name = list(vname)[i]
                label = str(vid[i].item())
                p_label_ = str(p_label[i].item())
                f.write('\n')
                f.write(name +','+label+ ',' +p_label_)
Example #15
    def process_epoch(self, train):
        if train:
            data_loader = self.train_loader
        else:
            data_loader = self.val_loader

        loss_unatt_agg = AverageMeter()
        acc_unatt_agg = AverageMeter()
        loss_att_agg = AverageMeter()
        acc_att_agg = AverageMeter()

        for i, (input, target) in enumerate(data_loader):
            if self.target:
                target = self.target * target.new_ones(target.size())
            input, target = input.to(self.device), target.to(self.device)

            with torch.no_grad():
                output = self.classifier(input)
                los_unatt = self.criterion(output, target)
            loss_unatt_agg.update(los_unatt.item(), input.size(0))
            acc_unatt_agg.update(accuracy(output, target).item(), input.size(0))

            with torch.set_grad_enabled(train):
                input_att, _ = self.framing(input=input)
                output_att = self.classifier(input_att)
                loss_att = self.criterion(output_att, target)
            loss_att_agg.update(loss_att.item(), input.size(0))
            acc_att_agg.update(accuracy(output_att, target).item(), input.size(0))

            if train:
                self.optimizer.zero_grad()
                framing_loss = loss_att if self.target is not None else -loss_att
                framing_loss.backward()
                self.optimizer.step()
                self.step += input.size(0)

            if train:
                if (i + 1) % self.args.print_freq == 0:
                    self.logger.log_kv([
                        ('unatt_loss', loss_unatt_agg.avg),
                        ('att_loss', loss_att_agg.avg),
                        ('unatt_acc', acc_unatt_agg.avg),
                        ('att_acc', acc_att_agg.avg),
                    ], prefix='train', step=self.step, write_to_tb=True)
                    loss_unatt_agg.reset()
                    loss_att_agg.reset()
                    acc_unatt_agg.reset()
                    acc_att_agg.reset()
            else:
                if i + 1 == len(data_loader):
                    self.logger.log_kv([
                        ('unatt_loss', loss_unatt_agg.avg),
                        ('att_loss', loss_att_agg.avg),
                        ('unatt_acc', acc_unatt_agg.avg),
                        ('att_acc', acc_att_agg.avg),
                    ], prefix='eval', step=self.step, write_to_tb=True, write_to_file=True)
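Example #15 calls an accuracy(output, target) helper that is not shown. A minimal sketch assuming it returns top-1 accuracy as a tensor (so that the .item() calls above apply):

import torch

def accuracy(output, target):
    """Fraction of samples whose arg-max prediction matches the target (top-1 accuracy)."""
    # output: (batch, num_classes) logits; target: (batch,) class indices
    preds = output.argmax(dim=1)
    return (preds == target).float().mean()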
Example #16
def train(epoch):
    """Train of the net."""
    net.train()
    total_loss = AverageMeter()
    epoch_loss_stats = AverageMeter()
    start_time = time.time()

    bar = tqdm(enumerate(train_loader))
    for batch_idx, sample in bar:
        optimizer.zero_grad()
        # TODO: Call the train routine for the net
        # outputs, label = routines.train_routine(sample)
        loss = criterion(outputs, label)
        loss.backward()
        optimizer.step()
        total_loss.update(loss.data, n=outputs.size(0))
        epoch_loss_stats.update(loss.data, n=outputs.size(0))

        if batch_idx % cfg.TRAIN.BACKUP_ITERS == 0:
            filename = '{0}_snapshot.pth'.format(cfg.DATASET.SPLIT)
            filename = osp.join(cfg.OUTPUT_DIR, filename)
            state_dict = net.state_dict()
            torch.save(state_dict, filename)

            optim_filename = '{0}_optim.pth'.format(cfg.DATASET.SPLIT)
            optim_filename = osp.join(cfg.OUTPUT_DIR, optim_filename)
            state_dict = optimizer.state_dict()
            torch.save(state_dict, optim_filename)

        if batch_idx % cfg.TRAIN.LOG_INTERVAL == 0:
            elapsed_time = time.time() - start_time
            bar.set_description('[{:5d}] ({:5d}/{:5d}) | ms/batch {:.6f} |'
                                ' loss {:.6f} | lr {:.7f}'.format(
                                    epoch, batch_idx, len(train_loader),
                                    elapsed_time * 1000, total_loss.avg,
                                    optimizer.param_groups[0]['lr']))
            total_loss.reset()

        start_time = time.time()

    epoch_total_loss = epoch_loss_stats.avg
    return epoch_total_loss
Example #17
 def validate(self, epoch):
     self.model.eval()
     losses = AverageMeter()
     times = AverageMeter()
     losses.reset()
     times.reset()
     len_d = len(self.valid_loader)
     init_time = time.time()
     end = init_time
     for i, data in enumerate(self.valid_loader):
         begin = time.time()
         input, label = data
         if torch.sum(label[0]) < 1:
             continue
         output = self.model(input)
         loss = self.loss_fn(output, label)
         loss_avg = torch.mean(loss)
         losses.update(loss_avg.item())
         times.update(time.time() - end)
         end = time.time()
         print(
             'epoch %d, %d/%d, validation loss: %f, time estimated: %.2f/%.2f seconds'
             % (epoch, i + 1, len_d, losses.avg, end - init_time,
                times.avg * len_d),
             end='\r')
     print("\n")
     if losses.avg < self.min_loss:
         self.early_stop_count = 0
         self.min_loss = losses.avg
         saved_dict = {
             'model': self.model.state_dict(),
             'epoch': epoch,
             'optimizer': self.optimizer,
             'cv_loss': self.min_loss,
             "early_stop_count": self.early_stop_count
         }
         torch.save(saved_dict, self.output_path + "/final.mdl")
         print("Saved new model")
     else:
         self.early_stop_count += 1
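Example #17 writes a dict-style checkpoint; note that it stores the optimizer object itself rather than optimizer.state_dict(), which makes the file larger and less portable. A sketch of how such a checkpoint could be restored, using the keys the snippet writes (the helper name and surrounding objects are illustrative):

import torch

def resume_from_checkpoint(model, path):
    """Restore the state saved by validate() above and return the resume metadata."""
    checkpoint = torch.load(path, map_location='cpu')
    model.load_state_dict(checkpoint['model'])
    return checkpoint['epoch'] + 1, checkpoint['cv_loss'], checkpoint['early_stop_count']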
Example #18
    def validate_model(self, data_loader, criterion):
        with torch.no_grad():
            device = 'cuda:0' if torch.cuda.is_available() else 'cpu'
            self.eval()
            self.to(device)
            loss_avg = AverageMeter()
            acc_avg = AverageMeter()
            loss_avg.reset()
            acc_avg.reset()

            for label, image in tqdm(data_loader):
                image: torch.Tensor = image.to(device)
                label: torch.Tensor = label.to(device).float().unsqueeze(dim=1)

                pred: torch.Tensor = self.forward(image)
                loss: torch.Tensor = criterion(pred, label)

                pred = pred >= 0.5
                loss_avg.update(loss.item())
                num_correct = (pred == label).sum().item()
                acc_avg.update(num_correct / image.shape[0])

            return loss_avg.avg, acc_avg.avg
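Examples #1 and #18 are the training and validation halves of the same model class. An illustrative driver for that pair is sketched below; the function and its arguments are placeholders, and a sigmoid-output binary classifier with BCELoss is assumed (both snippets threshold predictions at 0.5):

import torch

def run_training(model, train_loader, val_loader, num_epochs=10):
    # model is assumed to expose the train_model / validate_model methods from Examples #1 and #18
    criterion = torch.nn.BCELoss()
    optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
    for epoch in range(num_epochs):
        train_loss, train_acc = model.train_model(train_loader, criterion, optimizer)
        val_loss, val_acc = model.validate_model(val_loader, criterion)
        print(f"epoch {epoch}: train loss {train_loss:.4f} acc {train_acc:.3f} | "
              f"val loss {val_loss:.4f} acc {val_acc:.3f}")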
class Trainer():
    def __init__(self, disp_model, pose_model, optimizer, opt):
        self.disp_model = disp_model
        self.pose_model = pose_model
        self.optimizer = optimizer
        self.batch_time = AverageMeter()
        self.data_time = AverageMeter()
        self.losses = AverageMeter()

    def train(self, trainloader, epoch, opt):
        self.losses.reset()
        self.data_time.reset()
        self.batch_time.reset()
        end = time.time()
        self.disp_model.train()
        self.pose_model.train()
        for i, data in enumerate(trainloader, 0):
            self.optimizer.zero_grad()
            if opt.cuda:
                target_imgs, ref_imgs, intrinsics, intrinsics_inv = data
                target_imgs = Variable(target_imgs.cuda(non_blocking=True))  # 'async' is reserved in Python 3.7+
                ref_imgs = [Variable(img.cuda(non_blocking=True)) for img in ref_imgs]
                intrinsics = Variable(intrinsics.cuda(non_blocking=True))
                intrinsics_inv = Variable(intrinsics_inv.cuda(non_blocking=True))

            self.data_time.update(time.time() - end)
            disparities = self.disp_model(target_imgs)
            depths = [1 / disp for disp in disparities]
            explainability_mask, pose = self.pose_model(target_imgs, ref_imgs)

            photoloss = photometric_reconstruction_loss(
                target_imgs, ref_imgs, intrinsics, intrinsics_inv, depths,
                explainability_mask, pose, opt.rotation_mode, opt.padding_mode)
            exploss = explainability_loss(explainability_mask)
            smoothloss = smooth_loss(disparities)
            totalloss = opt.photo_loss_weight * photoloss + opt.mask_loss_weight * exploss + opt.smooth_loss_weight * smoothloss

            totalloss.backward()
            self.optimizer.step()

            inputs_size = intrinsics.size(0)
            self.losses.update(totalloss.data[0], inputs_size)
            self.batch_time.update(time.time() - end)
            end = time.time()

            if i % opt.printfreq == 0:
                print('Epoch: [{0}][{1}/{2}]\t'
                      'Time {batch_time.avg:.3f} ({batch_time.sum:.3f})\t'
                      'Data {data_time.avg:.3f} ({data_time.sum:.3f})\t'
                      'Loss {loss.avg:.3f}\t'.format(
                          epoch,
                          i,
                          len(trainloader),
                          batch_time=self.batch_time,
                          data_time=self.data_time,
                          loss=self.losses))
            sys.stdout.flush()
Example #20
    def validate(self, epoch):
        self.model.eval()
        losses = AverageMeter()
        times = AverageMeter()

        losses_snr = AverageMeter()

        losses.reset()
        times.reset()

        losses_snr.reset()

        len_d = len(self.valid_loader)
        end = time.time()
        with torch.no_grad():
            for i, data in enumerate(self.valid_loader):
                begin = time.time()
                input, label = data
                input = [ele.to(self.device) for ele in input]
                label = [ele.to(self.device) for ele in label]

                out_spec, out_wav = self.model(input)
                loss_snr = self.loss_fn(out_wav, label)

                loss = -loss_snr
                loss_avg = torch.mean(loss)
                losses.update(loss_avg.item())

                losses_snr_avg = torch.mean(loss_snr)

                losses_snr.update(losses_snr_avg.item())

                times.update(time.time() - end)
                end = time.time()
                writer.add_scalar('valid_loss/loss(snr)', losses.avg,
                                  epoch * len_d + i + 1)
                print(
                    'epoch %d, %d/%d, validation loss: %f, time estimated: %.2f seconds'
                    % (epoch, i + 1, len_d, losses.avg, times.avg * len_d),
                    end='\r')
            print("\n")

        if losses.avg < self.min_loss:
            self.early_stop_count = 0
            self.min_loss = losses.avg
            torch.save(self.model, self.output_path + "/model.epoch%d" % epoch)
            print("Saved new model")
        else:
            self.early_stop_count += 1
Example #21
def train(model, num_epochs, resume_epoch):
    loss_meter = AverageMeter()
    rpn_meter = AverageMeter()
    frcnn_meter = AverageMeter()
    for epoch in tqdm(range(num_epochs), total=num_epochs):
        pbar = tqdm(voc_loader, total=len(voc_loader), leave=True)
        for image, target in pbar:
            image = Variable(image).cuda(non_blocking=True)  # 'async' is reserved in Python 3.7+
            target = target.squeeze(0).numpy()
            rpn_cls_probs, rpn_bbox_deltas, pred_label, pred_bbox_deltas = frcnn(image)
            proposal_boxes, _ = frcnn.get_rpn_proposals()

            if len(proposal_boxes) == 0:
                continue

            rpn_labels, rpn_bbox_targets, rpn_batch_indices = frcnn.get_rpn_targets(target)
            detector_labels, delta_boxes, clf_batch_indices = frcnn.get_detector_targets(target)
            rpn_loss = criterion1(rpn_cls_probs, rpn_bbox_deltas,
                                  Variable(rpn_labels, requires_grad=False).cuda(),
                                  Variable(rpn_bbox_targets, requires_grad=False).cuda(),
                                  rpn_batch_indices.cuda())
            frcnn_loss = criterion2(pred_label, pred_bbox_deltas,
                                    Variable(detector_labels, requires_grad=False).cuda(),
                                    Variable(delta_boxes, requires_grad=False).cuda(),
                                    clf_batch_indices.cuda())
            total_loss = rpn_loss + frcnn_loss

            rpn_meter.update(rpn_loss.data[0])
            frcnn_meter.update(frcnn_loss.data[0])
            loss_meter.update(total_loss.data[0])

            pbar.set_description(desc='loss {:.4f} | rpn loss {:.4f} | frcnn loss {:.4f}'.format(loss_meter.avg,
                                                                                                 rpn_meter.avg,
                                                                                                 frcnn_meter.avg))
            optimizer.zero_grad()  # gradients were never cleared in the original loop; reset them before each backward pass
            total_loss.backward()
            optimizer.step()

        if (epoch + 1) % CHECKPOINT_RATE == 0:
            save_checkpoint(frcnn.state_dict(),
                            optimizer.state_dict(),
                            os.path.join(WEIGHT_DIR, "{}_{:.1e}_{:.4f}.pt".format(epoch + 1 + resume_epoch,
                                                                                  LEARNING_RATE, loss_meter.avg)))

        loss_meter.reset()
        rpn_meter.reset()
        frcnn_meter.reset()
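Example #21 saves weights through a save_checkpoint(model_state, optimizer_state, path) helper that is not shown. A minimal sketch of what such a helper presumably does, bundling both state dicts into one file; this is an assumption, not the original implementation:

import torch

def save_checkpoint(model_state, optimizer_state, path):
    """Write model and optimizer state dicts to a single checkpoint file."""
    torch.save({'model': model_state, 'optimizer': optimizer_state}, path)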
Example #22
def train(model, img_encoder, normalize, base_loader, optimizer, criterion,
          epoch, total_epoch, device, logger, nodes, desc_embeddings,
          id_to_class_name, classFile_to_wikiID):
    batch_time = AverageMeter()  # forward prop. + back prop. time
    data_time = AverageMeter()  # data loading time
    losses = AverageMeter()  # loss

    model.train()
    img_encoder.eval()
    start = time.time()

    for i, (imgs, labels, sp_labels) in enumerate(base_loader):
        data_time.update(time.time() - start)
        imgs = imgs.to(device)
        labels = labels.to(device)
        sp_labels = sp_labels.to(device)

        corr_nodeIndexs = find_nodeIndex_by_imgLabels(nodes, labels,
                                                      id_to_class_name,
                                                      classFile_to_wikiID)

        _, class_outputs, sp_outputs, att_features, corr_features = model(
            imgs, desc_embeddings, corr_nodeIndexs, norm=normalize)
        loss = criterion(class_outputs, sp_outputs, labels, sp_labels,
                         att_features, corr_features)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        losses.update(loss.item())
        batch_time.update(time.time() - start)

        start = time.time()

        if i % 30 == 29:  # print every 30 mini-batches
            logger.info(
                f'[{epoch:3d}/{total_epoch}|{i+1:3d}, '
                f'{len(base_loader)}] batch_time: {batch_time.avg:.2f} '
                f'data_time: {data_time.avg:.2f} loss: {losses.avg:.3f}')

            batch_time.reset()
            data_time.reset()
            losses.reset()
Example #23
def train(model, normalize, base_loader, optimizer, criterion, epoch,
          total_epoch, device, logger):
    batch_time = AverageMeter()  # forward prop. + back prop. time
    data_time = AverageMeter()  # data loading time
    losses = AverageMeter()  # loss

    model.train()
    start = time.time()

    for i, (imgs, labels, sp_labels) in enumerate(base_loader):
        data_time.update(time.time() - start)

        imgs = imgs.to(device)
        labels = labels.to(device)
        sp_labels = sp_labels.to(device)

        _, class_outputs, sp_outputs = model(imgs, norm=normalize)
        loss = criterion(class_outputs, sp_outputs, labels, sp_labels)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        losses.update(loss.item())
        batch_time.update(time.time() - start)

        start = time.time()

        if i % 30 == 29:  # print every 30 mini-batches
            logger.info(
                f'[{epoch:3d}/{total_epoch}|{i+1:3d}, '
                f'{len(base_loader)}] batch_time: {batch_time.avg:.2f} '
                f'data_time: {data_time.avg:.2f} loss: {losses.avg:.3f}')

            batch_time.reset()
            data_time.reset()
            losses.reset()
def train(args, net, optimizer, criterion, scheduler):
    log_file = open(args.save_root + "training.log", "w", 1)
    log_file.write(args.exp_name + '\n')
    for arg in vars(args):
        print(arg, getattr(args, arg))
        log_file.write(str(arg) + ': ' + str(getattr(args, arg)) + '\n')
    log_file.write(str(net))
    net.train()

    # loss counters
    batch_time = AverageMeter()
    losses = AverageMeter()
    loc_losses = AverageMeter()
    cls_losses = AverageMeter()

    print('Loading Dataset...')
    train_dataset = UCF24Detection(args.data_root,
                                   args.train_sets,
                                   SSDAugmentation(args.ssd_dim, args.means),
                                   AnnotationTransform(),
                                   input_type=args.input_type)
    val_dataset = UCF24Detection(args.data_root,
                                 'test',
                                 BaseTransform(args.ssd_dim, args.means),
                                 AnnotationTransform(),
                                 input_type=args.input_type,
                                 full_test=False)
    epoch_size = len(train_dataset) // args.batch_size
    print('Training SSD on', train_dataset.name)

    if args.visdom:

        import visdom
        viz = visdom.Visdom()
        viz.port = 8097
        viz.env = args.exp_name
        # initialize visdom loss plot
        lot = viz.line(X=torch.zeros((1, )).cpu(),
                       Y=torch.zeros((1, 6)).cpu(),
                       opts=dict(xlabel='Iteration',
                                 ylabel='Loss',
                                 title='Current SSD Training Loss',
                                 legend=[
                                     'REG', 'CLS', 'AVG', 'S-REG', ' S-CLS',
                                     ' S-AVG'
                                 ]))
        # initialize visdom meanAP and class APs plot
        legends = ['meanAP']
        for cls in CLASSES:
            legends.append(cls)
        val_lot = viz.line(X=torch.zeros((1, )).cpu(),
                           Y=torch.zeros((1, args.num_classes)).cpu(),
                           opts=dict(xlabel='Iteration',
                                     ylabel='Mean AP',
                                     title='Current SSD Validation mean AP',
                                     legend=legends))

    batch_iterator = None
    train_data_loader = data.DataLoader(train_dataset,
                                        args.batch_size,
                                        num_workers=args.num_workers,
                                        shuffle=True,
                                        collate_fn=detection_collate,
                                        pin_memory=True)
    val_data_loader = data.DataLoader(val_dataset,
                                      args.batch_size,
                                      num_workers=args.num_workers,
                                      shuffle=False,
                                      collate_fn=detection_collate,
                                      pin_memory=True)
    itr_count = 0
    torch.cuda.synchronize()
    t0 = time.perf_counter()
    for iteration in range(args.max_iter + 1):
        if (not batch_iterator) or (iteration % epoch_size == 0):
            # create batch iterator
            batch_iterator = iter(train_data_loader)

        # load train data
        images, targets, img_indexs = next(batch_iterator)
        if args.cuda:
            images = Variable(images.cuda())
            targets = [
                Variable(anno.cuda(), volatile=True) for anno in targets
            ]
        else:
            images = Variable(images)
            targets = [Variable(anno, volatile=True) for anno in targets]
        # forward
        out = net(images)
        # backprop
        optimizer.zero_grad()

        loss_l, loss_c = criterion(out, targets)
        loss = loss_l + loss_c
        loss.backward()
        optimizer.step()
        scheduler.step()
        loc_loss = loss_l.data[0]
        conf_loss = loss_c.data[0]
        # print('Loss data type ',type(loc_loss))
        loc_losses.update(loc_loss)
        cls_losses.update(conf_loss)
        losses.update((loc_loss + conf_loss) / 2.0)

        if iteration % args.print_step == 0 and iteration > 0:
            if args.visdom:
                losses_list = [
                    loc_losses.val, cls_losses.val, losses.val, loc_losses.avg,
                    cls_losses.avg, losses.avg
                ]
                viz.line(X=torch.ones((1, 6)).cpu() * iteration,
                         Y=torch.from_numpy(
                             np.asarray(losses_list)).unsqueeze(0).cpu(),
                         win=lot,
                         update='append')

            torch.cuda.synchronize()
            t1 = time.perf_counter()
            batch_time.update(t1 - t0)

            print_line = 'Iteration {:06d}/{:06d} loc-loss {:.3f}({:.3f}) cls-loss {:.3f}({:.3f}) ' \
                         'average-loss {:.3f}({:.3f}) Timer {:0.3f}({:0.3f})'.format(
                          iteration, args.max_iter, loc_losses.val, loc_losses.avg, cls_losses.val,
                          cls_losses.avg, losses.val, losses.avg, batch_time.val, batch_time.avg)

            torch.cuda.synchronize()
            t0 = time.perf_counter()
            log_file.write(print_line + '\n')
            print(print_line)

            # if args.visdom and args.send_images_to_visdom:
            #     random_batch_index = np.random.randint(images.size(0))
            #     viz.image(images.data[random_batch_index].cpu().numpy())
            itr_count += 1

            if itr_count % args.loss_reset_step == 0 and itr_count > 0:
                loc_losses.reset()
                cls_losses.reset()
                losses.reset()
                batch_time.reset()
                print('Reset accumulators of ', args.exp_name, ' at',
                      itr_count * args.print_step)
                itr_count = 0

        if (iteration % args.eval_step == 0
                or iteration == 5000) and iteration > 0:
            torch.cuda.synchronize()
            tvs = time.perf_counter()
            print('Saving state, iter:', iteration)
            torch.save(
                net.state_dict(),
                args.save_root + 'ssd300_ucf24_' + repr(iteration) + '.pth')

            net.eval()  # switch net to evaluation mode
            mAP, ap_all, ap_strs = validate(args,
                                            net,
                                            val_data_loader,
                                            val_dataset,
                                            iteration,
                                            iou_thresh=args.iou_thresh)

            for ap_str in ap_strs:
                print(ap_str)
                log_file.write(ap_str + '\n')
            ptr_str = '\nMEANAP:::=>' + str(mAP) + '\n'
            print(ptr_str)
            log_file.write(ptr_str)

            if args.visdom:
                aps = [mAP]
                for ap in ap_all:
                    aps.append(ap)
                viz.line(X=torch.ones((1, args.num_classes)).cpu() * iteration,
                         Y=torch.from_numpy(
                             np.asarray(aps)).unsqueeze(0).cpu(),
                         win=val_lot,
                         update='append')
            net.train()  # Switch net back to training mode
            torch.cuda.synchronize()
            t0 = time.perf_counter()
            prt_str = '\nValidation TIME::: {:0.3f}\n\n'.format(t0 - tvs)
            print(prt_str)
            log_file.write(prt_str)  # log the validation-time line (prt_str), not the earlier mAP string

    log_file.close()
Example #25
def train():
    """ Train the model using the parameters defined in the config file """
    print('Initialising ...')
    cfg = TrainConfig()
    checkpoint_folder = 'checkpoints/{}/'.format(cfg.experiment_name)

    if not os.path.exists(checkpoint_folder):
        os.makedirs(checkpoint_folder)

    tb_folder = 'tb/{}/'.format(cfg.experiment_name)
    if not os.path.exists(tb_folder):
        os.makedirs(tb_folder)

    writer = SummaryWriter(logdir=tb_folder, flush_secs=30)
    model = ParrotModel().cuda().train()
    optimiser = AdamW(model.parameters(),
                      lr=cfg.initial_lr,
                      weight_decay=cfg.weight_decay)

    train_dataset = ParrotDataset(cfg.train_labels, cfg.mp3_folder)
    train_loader = DataLoader(train_dataset,
                              batch_size=cfg.batch_size,
                              num_workers=cfg.workers,
                              collate_fn=parrot_collate_function,
                              pin_memory=True)

    val_dataset = ParrotDataset(cfg.val_labels, cfg.mp3_folder)
    val_loader = DataLoader(val_dataset,
                            batch_size=cfg.batch_size,
                            num_workers=cfg.workers,
                            collate_fn=parrot_collate_function,
                            shuffle=False,
                            pin_memory=True)

    epochs = cfg.epochs
    init_loss, step = 0., 0
    avg_loss = AverageMeter()
    print('Starting training')
    for epoch in range(epochs):
        loader_length = len(train_loader)
        epoch_start = time.time()

        for batch_idx, batch in enumerate(train_loader):
            optimiser.zero_grad()

            # VRAM control by skipping long examples
            if batch['spectrograms'].shape[-1] > cfg.max_time:
                continue

            # inference
            target = batch['targets'].cuda()
            model_input = batch['spectrograms'].cuda()
            model_output = model(model_input)

            # loss
            input_lengths = batch['input_lengths'].cuda()
            target_lengths = batch['target_lengths'].cuda()
            loss = ctc_loss(model_output, target, input_lengths,
                            target_lengths)
            loss.backward()

            if epoch == 0 and batch_idx == 0:
                init_loss = loss

            # logging
            elapsed = time.time() - epoch_start
            progress = batch_idx / loader_length
            est = datetime.timedelta(
                seconds=int(elapsed / progress)) if progress > 0.001 else '-'
            avg_loss.update(loss)
            suffix = '\tloss {:.4f}/{:.4f}\tETA [{}/{}]'.format(
                avg_loss.avg, init_loss,
                datetime.timedelta(seconds=int(elapsed)), est)
            printProgressBar(batch_idx,
                             loader_length,
                             suffix=suffix,
                             prefix='Epoch [{}/{}]\tStep [{}/{}]'.format(
                                 epoch, epochs, batch_idx, loader_length))

            writer.add_scalar('Steps/train_loss', loss, step)

            # saving the model
            if step % cfg.checkpoint_every == 0:
                test_name = '{}/test_epoch{}.mp3'.format(
                    checkpoint_folder, epoch)
                test_mp3_file(cfg.test_mp3, model, test_name)
                checkpoint_name = '{}/epoch_{}.pth'.format(
                    checkpoint_folder, epoch)
                torch.save(
                    {
                        'model': model.state_dict(),
                        'epoch': epoch,
                        'batch_idx': loader_length,
                        'step': step,
                        'optimiser': optimiser.state_dict()
                    }, checkpoint_name)

            # validating
            if step % cfg.val_every == 0:
                val(model, val_loader, writer, step)
                model = model.train()

            step += 1
            optimiser.step()

        # end of epoch
        print('')
        writer.add_scalar('Epochs/train_loss', avg_loss.avg, epoch)
        avg_loss.reset()
        test_name = '{}/test_epoch{}.mp3'.format(checkpoint_folder, epoch)
        test_mp3_file(cfg.test_mp3, model, test_name)
        checkpoint_name = '{}/epoch_{}.pth'.format(checkpoint_folder, epoch)
        torch.save(
            {
                'model': model.state_dict(),
                'epoch': epoch,
                'batch_idx': loader_length,
                'step': step,
                'optimiser': optimiser.state_dict()
            }, checkpoint_name)

    # finished training
    writer.close()
    print('Training finished :)')
Example #26
def train():
    set_seed(seed=10)
    os.makedirs(args.save_root, exist_ok=True)

    # create model, optimizer and criterion
    model = SSD300(n_classes=len(label_map), device=device)
    biases = []
    not_biases = []
    for name, param in model.named_parameters():
        if param.requires_grad:
            if name.endswith('.bias'):
                biases.append(param)
            else:
                not_biases.append(param)
    model = model.to(device)
    optimizer = torch.optim.SGD(params=[{
        'params': biases,
        'lr': 2 * args.lr
    }, {
        'params': not_biases
    }],
                                lr=args.lr,
                                momentum=args.momentum,
                                weight_decay=args.weight_decay)
    if args.resume is None:
        start_epoch = 0
    else:
        checkpoint = torch.load(args.resume, map_location=device)
        start_epoch = checkpoint['epoch'] + 1
        model.load_state_dict(checkpoint['model'])
        optimizer.load_state_dict(checkpoint['optimizer'])
    print(f'Training will start at epoch {start_epoch}.')

    criterion = MultiBoxLoss(priors_cxcy=model.priors_cxcy,
                             device=device,
                             alpha=args.alpha)
    criterion = criterion.to(device)
    '''
    scheduler = StepLR(optimizer=optimizer,
                       step_size=20,
                       gamma=0.5,
                       last_epoch=start_epoch - 1,
                       verbose=True)
    '''

    # load data
    transform = Transform(size=(300, 300), train=True)
    train_dataset = VOCDataset(root=args.data_root,
                               image_set=args.image_set,
                               transform=transform,
                               keep_difficult=True)
    train_loader = DataLoader(dataset=train_dataset,
                              collate_fn=collate_fn,
                              batch_size=args.batch_size,
                              num_workers=args.num_workers,
                              shuffle=True,
                              pin_memory=True)

    losses = AverageMeter()
    for epoch in range(start_epoch, args.num_epochs):
        # decay learning rate at particular epochs
        if epoch in [120, 140, 160]:
            adjust_learning_rate(optimizer, 0.1)

        # train model
        model.train()
        losses.reset()
        bar = tqdm(train_loader, desc='Train the model')
        for i, (images, bboxes, labels, _) in enumerate(bar):
            images = images.to(device)
            bboxes = [b.to(device) for b in bboxes]
            labels = [l.to(device) for l in labels]

            predicted_bboxes, predicted_scores = model(
                images)  # (N, 8732, 4), (N, 8732, num_classes)
            loss = criterion(predicted_bboxes, predicted_scores, bboxes,
                             labels)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            losses.update(loss.item(), images.size(0))

            if i % args.print_freq == args.print_freq - 1:
                bar.write(f'Average Loss: {losses.avg:.4f}')

        bar.write(f'Epoch: [{epoch + 1}|{args.num_epochs}] '
                  f'Average Loss: {losses.avg:.4f}')
        # adjust learning rate
        # scheduler.step()

        # save model
        state_dict = {
            'epoch': epoch,
            'model': model.state_dict(),
            'optimizer': optimizer.state_dict()
        }
        save_path = os.path.join(args.save_root, 'ssd300.pth')
        torch.save(state_dict, save_path)

        if epoch % args.save_freq == args.save_freq - 1:
            shutil.copyfile(
                save_path,
                os.path.join(args.save_root, f'ssd300_epochs_{epoch + 1}.pth'))
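Example #26 decays the learning rate at epochs 120, 140, and 160 through adjust_learning_rate(optimizer, 0.1), which is not defined in the snippet. A minimal sketch assuming it simply scales every parameter group's rate by the given factor (the usual SSD training schedule):

def adjust_learning_rate(optimizer, scale):
    """Multiply the learning rate of every parameter group by scale."""
    for param_group in optimizer.param_groups:
        param_group['lr'] *= scale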
Example #27
class Trainer(object):  # the most basic model
    def __init__(self, config, data_loader=None):
        self.config = config
        self.data_loader = data_loader  # needed for VAE

        self.lr = config.lr
        self.beta1 = config.beta1
        self.beta2 = config.beta2
        self.optimizer = config.optimizer
        self.batch_size = config.batch_size

        self.diffLoss = L1Loss_mask()  # custom module

        self.valmin_iter = 0
        self.model_dir = 'logs/' + str(config.expnum)
        self.savename_G = ''
        self.decoder = GreedyDecoder(data_loader.labels)

        self.kt = 0  # used for Proportional Control Theory in BEGAN, initialized as 0
        self.lb = 0.001
        self.conv_measure = 0  # convergence measure

        self.dce_tr = AverageMeter()
        self.dce_val = AverageMeter()
        self.wer_tr = AverageMeter()
        self.cer_tr = AverageMeter()
        self.wer_val = AverageMeter()
        self.cer_val = AverageMeter()

        self.build_model()
        self.G.loss_stop = 100000
        #self.get_weight_statistic()

        if self.config.gpu >= 0:
            self.G.cuda()
            self.ASR.cuda()

        if len(self.config.load_path) > 0:
            self.load_model()

        if config.mode == 'train':
            self.logFile = open(self.model_dir + '/log.txt', 'w')

    def zero_grad_all(self):
        self.G.zero_grad()

    def build_model(self):
        print('initialize enhancement model')
        self.G = stackedBRNN(I=self.config.nFeat,
                             H=self.config.rnn_size,
                             L=self.config.rnn_layers,
                             rnn_type=supported_rnns[self.config.rnn_type])

        print('load pre-trained ASR model')
        package_ASR = torch.load(self.config.ASR_path,
                                 map_location=lambda storage, loc: storage)
        self.ASR = DeepSpeech.load_model_package(package_ASR)
        # Weight initialization is done inside the module

    def load_model(self):
        print("[*] Load models from {}...".format(self.load_path))
        postfix = '_valmin'
        paths = glob(os.path.join(self.load_path, 'G{}*.pth'.format(postfix)))
        paths.sort()

        if len(paths) == 0:
            print("[!] No checkpoint found in {}...".format(self.load_path))
            assert (0), 'checkpoint not available'

        idxes = [
            int(os.path.basename(path.split('.')[0].split('_')[-1]))
            for path in paths
        ]
        if self.config.start_iter < 0:
            self.config.start_iter = max(idxes)
            if (self.config.start_iter < 0):  # if still negative, raise an error
                raise Exception(
                    "start_iter is still less than 0 --> probably trying to load an initial random model"
                )

        if self.config.gpu < 0:  #CPU
            map_location = lambda storage, loc: storage
        else:  # GPU
            map_location = None

        # Ver2
        print('Load models from ' + self.config.load_path + ', ITERATION = ' +
              str(self.config.start_iter))
        self.G.load_state_dict(
            torch.load(os.path.join(self.config.load_path,
                                    'G{}_{}.pth'.format(postfix, self.config.start_iter)),
                       map_location=map_location))

        print("[*] Model loaded")

    def train(self):
        # Setting
        optimizer_g = torch.optim.Adam(self.G.parameters(),
                                       lr=self.config.lr,
                                       betas=(self.beta1, self.beta2),
                                       amsgrad=True)

        for iter in trange(self.config.start_iter, self.config.max_iter):
            # Train
            data_list = self.data_loader.next(cl_ny='ny', type='train')
            inputs, cleans, mask = _get_variable_nograd(
                data_list[0]), _get_variable_nograd(
                    data_list[1]), _get_variable_nograd(data_list[2])

            # forward
            outputs = self.G(inputs)
            dce, nElement = self.diffLoss(
                outputs, cleans, mask)  # already normalized inside function

            # backward
            self.zero_grad_all()
            dce.backward()
            optimizer_g.step()

            # log
            #pdb.set_trace()
            if (iter + 1) % self.config.log_iter == 0:
                str_loss = "[{}/{}] (train) DCE: {:.7f}".format(
                    iter, self.config.max_iter, dce.data[0])
                print(str_loss)
                self.logFile.write(str_loss + '\n')
                self.logFile.flush()

            if (iter + 1) % self.config.save_iter == 0:
                self.G.eval()
                # Measure performance on training subset
                self.dce_tr.reset()
                self.wer_tr.reset()
                self.cer_tr.reset()
                for _ in trange(0, len(self.data_loader.trsub_dl)):
                    data_list = self.data_loader.next(cl_ny='ny', type='trsub')
                    inputs, cleans, mask, targets, input_percentages, target_sizes = \
                        _get_variable_volatile(data_list[0]), _get_variable_volatile(data_list[1]), _get_variable_volatile(data_list[2]), \
                        data_list[3], data_list[4], data_list[5]

                    outputs = self.G(inputs)
                    dce, nElement = self.diffLoss(
                        outputs, cleans,
                        mask)  # already normalized inside function
                    self.dce_tr.update(dce.data[0], nElement)

                    # Greedy decoding
                    wer, cer, nWord, nChar = self.greedy_decoding(
                        inputs, targets, input_percentages, target_sizes)
                    self.wer_tr.update(wer, nWord)
                    self.cer_tr.update(cer, nChar)

                str_loss = "[{}/{}] (training subset) DCE: {:.7f}".format(
                    iter, self.config.max_iter, self.dce_tr.avg)
                print(str_loss)
                self.logFile.write(str_loss + '\n')

                str_loss = "[{}/{}] (training subset) WER: {:.7f}, CER: {:.7f}".format(
                    iter, self.config.max_iter, self.wer_tr.avg * 100,
                    self.cer_tr.avg * 100)
                print(str_loss)
                self.logFile.write(str_loss + '\n')

                # Measure performance on validation data
                self.dce_val.reset()
                self.wer_val.reset()
                self.cer_val.reset()
                for _ in trange(0, len(self.data_loader.val_dl)):
                    data_list = self.data_loader.next(cl_ny='ny', type='val')
                    inputs, cleans, mask, targets, input_percentages, target_sizes = \
                        _get_variable_volatile(data_list[0]), _get_variable_volatile(data_list[1]), _get_variable_volatile(data_list[2]), \
                        data_list[3], data_list[4], data_list[5]

                    outputs = self.G(inputs)
                    dce, nElement = self.diffLoss(
                        outputs, cleans,
                        mask)  # already normalized inside function
                    self.dce_val.update(dce.data[0], nElement)

                    # Greedy decoding
                    wer, cer, nWord, nChar = self.greedy_decoding(
                        inputs, targets, input_percentages, target_sizes)
                    self.wer_val.update(wer, nWord)
                    self.cer_val.update(cer, nChar)

                str_loss = "[{}/{}] (validation) DCE: {:.7f}".format(
                    iter, self.config.max_iter, self.dce_val.avg)
                print(str_loss)
                self.logFile.write(str_loss + '\n')

                str_loss = "[{}/{}] (validation) WER: {:.7f}, CER: {:.7f}".format(
                    iter, self.config.max_iter, self.wer_val.avg * 100,
                    self.cer_val.avg * 100)
                print(str_loss)
                self.logFile.write(str_loss + '\n')

                self.G.train()  # end of validation
                self.logFile.flush()

                # Save model
                if (len(self.savename_G) > 0):  # do not remove here
                    if os.path.exists(self.savename_G):
                        os.remove(self.savename_G)  # remove previous model
                self.savename_G = '{}/G_{}.pth'.format(self.model_dir, iter)
                torch.save(self.G.state_dict(), self.savename_G)

                if (self.G.loss_stop > self.wer_val.avg):
                    self.G.loss_stop = self.wer_val.avg
                    savename_G_valmin_prev = '{}/G_valmin_{}.pth'.format(
                        self.model_dir, self.valmin_iter)
                    if os.path.exists(savename_G_valmin_prev):
                        os.remove(
                            savename_G_valmin_prev)  # remove previous model

                    print('save model for this checkpoint')
                    savename_G_valmin = '{}/G_valmin_{}.pth'.format(
                        self.model_dir, iter)
                    copyfile(self.savename_G, savename_G_valmin)
                    self.valmin_iter = iter

    def greedy_decoding(self,
                        inputs,
                        targets,
                        input_percentages,
                        target_sizes,
                        transcript_prob=0.001):
        # unflatten targets
        split_targets = []
        offset = 0
        for size in target_sizes:
            split_targets.append(targets[offset:offset + size])
            offset += size

        # step 1) Decoding to get wer & cer
        enhanced = self.G(inputs)
        prob = self.ASR(enhanced)
        prob = prob.transpose(0, 1)
        T = prob.size(0)
        sizes = input_percentages.mul_(int(T)).int()

        decoded_output, _ = self.decoder.decode(prob.data, sizes)
        target_strings = self.decoder.convert_to_strings(split_targets)
        we, ce, total_word, total_char = 0, 0, 0, 0

        for x in range(len(target_strings)):
            decoding, reference = decoded_output[x][0], target_strings[x][0]
            nChar = len(reference)
            nWord = len(reference.split())
            we_i = self.decoder.wer(decoding, reference)
            ce_i = self.decoder.cer(decoding, reference)
            we += we_i
            ce += ce_i
            total_word += nWord
            total_char += nChar
            if (random.uniform(0, 1) < transcript_prob):
                print('reference = ' + reference)
                print('decoding = ' + decoding)
                print('wer = ' + str(we_i / float(nWord)) + ', cer = ' +
                      str(ce_i / float(nChar)))

        wer = we / total_word
        cer = ce / total_char

        return wer, cer, total_word, total_char
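greedy_decoding above delegates the raw error counts to self.decoder.wer and self.decoder.cer and only normalizes them afterwards, so those decoder calls are presumably un-normalized edit distances at the word and character level. A self-contained sketch of that computation, with illustrative function names that are not the decoder's actual API:

def edit_distance(ref, hyp):
    # Standard dynamic-programming Levenshtein distance between two sequences.
    prev = list(range(len(hyp) + 1))
    for i, r in enumerate(ref, 1):
        cur = [i]
        for j, h in enumerate(hyp, 1):
            cur.append(min(prev[j] + 1,              # deletion
                           cur[j - 1] + 1,           # insertion
                           prev[j - 1] + (r != h)))  # substitution
        prev = cur
    return prev[-1]

def word_errors(decoding, reference):
    # un-normalized word-level distance, analogous to how decoder.wer() is used above
    return edit_distance(reference.split(), decoding.split())

def char_errors(decoding, reference):
    # un-normalized character-level distance, analogous to how decoder.cer() is used above
    return edit_distance(reference, decoding)

Dividing word_errors by the number of reference words and char_errors by the number of reference characters gives WER and CER, which matches the normalization performed at the end of greedy_decoding.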
Exemplo n.º 28
0
class Trainer():
    def __init__(self, model, criterion, optimizer, opt, writer):
        self.model = model
        self.criterion = criterion
        self.optimizer = optimizer
        self.batch_time = AverageMeter()
        self.data_time = AverageMeter()
        self.losses = AverageMeter()
        self.writer = writer

    def train(self, trainloader, epoch, opt):
        self.data_time.reset()
        self.batch_time.reset()
        self.model.train()
        self.losses.reset()

        end = time.time()
        for i, data in enumerate(trainloader, 0):

            self.optimizer.zero_grad()

            xh, xi, xp, shifted_targets, eyes, names, eyes2, gcorrs = data
            xh = xh.cpu()
            xi = xi.cpu()
            xp = xp.cpu()
            shifted_targets = shifted_targets.cpu().squeeze()

            self.data_time.update(time.time() - end)

            outputs = self.model(xh, xi, xp)
            total_loss = self.criterion(outputs[0],
                                        shifted_targets[:, 0, :].max(1)[1])
            for j in range(1, len(outputs)):
                total_loss += self.criterion(
                    outputs[j], shifted_targets[:, j, :].max(1)[1])

            total_loss = total_loss / (len(outputs) * 1.0)

            total_loss.backward()
            self.optimizer.step()

            inputs_size = xh.size(0)
            self.losses.update(total_loss.item(), inputs_size)
            self.batch_time.update(time.time() - end)
            end = time.time()

            if i % opt.printfreq == 0 and opt.verbose:
                print('Epoch: [{0}][{1}/{2}]\t'
                      'Time {batch_time.avg:.3f} ({batch_time.sum:.3f})\t'
                      'Data {data_time.avg:.3f} ({data_time.sum:.3f})\t'
                      'Loss {loss.avg:.3f}\t'.format(
                          epoch,
                          i,
                          len(trainloader),
                          batch_time=self.batch_time,
                          data_time=self.data_time,
                          loss=self.losses))

            sys.stdout.flush()

        self.writer.add_scalar('Train Loss', self.losses.avg, epoch)
        print('Train: [{0}]\t'
              'Time {batch_time.sum:.3f}\t'
              'Data {data_time.sum:.3f}\t'
              'Loss {loss.avg:.3f}\t'.format(epoch,
                                             batch_time=self.batch_time,
                                             data_time=self.data_time,
                                             loss=self.losses))
Exemplo n.º 29
0
class Validator():
    def __init__(self, model, criterion, opt, writer):

        self.model = model
        self.criterion = criterion
        self.batch_time = AverageMeter()
        self.data_time = AverageMeter()
        self.dist = AverageMeter()
        self.mindist = AverageMeter()
        self.writer = writer

    def validate(self, valloader, epoch, opt):

        self.model.eval()
        self.dist.reset()
        self.mindist.reset()
        self.data_time.reset()
        self.batch_time.reset()
        end = time.time()

        with torch.no_grad():
            for i, data in enumerate(valloader, 0):

                xh, xi, xp, targets, eyes, names, eyes2, ground_labels = data
                xh = xh.cpu()
                xi = xi.cpu()
                xp = xp.cpu()

                self.data_time.update(time.time() - end)
                outputs = self.model.predict(xh, xi, xp)

                pred_labels = outputs.max(1)[1]
                inputs_size = xh.size(0)

                distval = utils.euclid_dist(pred_labels.data.cpu(),
                                            ground_labels, inputs_size)
                # mindistval = utils.euclid_mindist(pred_labels.data.cpu(), ground_labels, inputs_size)

                self.dist.update(distval, inputs_size)
                #self.mindist.update(mindistval, inputs_size)
                self.batch_time.update(time.time() - end)
                end = time.time()

                if i % opt.printfreq == 0 and opt.verbose:
                    print('Epoch: [{0}][{1}/{2}]\t'
                          'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t'
                          'Data {data_time.val:.3f} ({data_time.avg:.3f})\t'
                          'Dist {dist.avg:.3f}\t'.format(
                              epoch,
                              i,
                              len(valloader),
                              batch_time=self.batch_time,
                              data_time=self.data_time,
                              dist=self.dist))

                sys.stdout.flush()

            self.writer.add_scalar('Val Dist', self.dist.avg, epoch)
            #self.writer.add_scalar('Val Min Dist', self.mindist.avg, epoch)

            print('Val: [{0}]\t'
                  'Time {batch_time.sum:.3f}\t'
                  'Data {data_time.sum:.3f}\t'
                  'Dist {dist.avg:.3f}\t'.format(epoch,
                                                 batch_time=self.batch_time,
                                                 data_time=self.data_time,
                                                 dist=self.dist))

        return self.dist.avg
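The only metric this Validator reports comes from utils.euclid_dist, which is not shown. A hedged reading, assuming pred_labels and ground_labels index cells of a square output grid (the grid size below is a made-up placeholder, not a value taken from this project), is the batch-averaged Euclidean distance between predicted and ground-truth cells in normalized coordinates:

import torch

def euclid_dist(pred_labels, ground_labels, batch_size, grid_size=15):
    # grid_size is hypothetical; the real project defines its own output resolution.
    pred = torch.stack((pred_labels // grid_size, pred_labels % grid_size), dim=1).float()
    gt = torch.stack((ground_labels // grid_size, ground_labels % grid_size), dim=1).float()
    # normalize cell coordinates to [0, 1] so the distance is resolution-independent
    dist = torch.norm((pred - gt) / grid_size, dim=1)
    return dist.sum().item() / batch_size

Because this returns a per-sample average, weighting it by inputs_size in self.dist.update keeps the running meter a true per-sample average across batches.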
Exemplo n.º 30
0
class Trainer(object):  # the most basic model
    def __init__(self, config, data_loader=None):
        if (config.w_minWvar > 0):
            config.minimize_W_var = True
            self.varLoss = var_mask()

        self.config = config
        self.data_loader = data_loader  # needed for VAE

        self.lr = config.lr
        self.beta1 = config.beta1
        self.beta2 = config.beta2
        self.optimizer = config.optimizer
        self.batch_size = config.batch_size

        self.diffLoss = L1Loss_mask()  # custom module

        log_domain = False
        if (self.config.linear_to_mel):
            log_domain = True
        self.get_SNRout = get_SNRout(log_domain=log_domain)

        self.valmin_iter = 0
        self.model_dir = 'models/' + str(config.expnum)
        self.log_dir = 'logs_only/' + str(config.expnum)
        self.savename_G = ''
        self.decoder = GreedyDecoder(data_loader.labels)

        self.kt = 0  # used for Proportional Control Theory in BEGAN, initialized as 0
        self.lb = 0.001
        self.conv_measure = 0  # convergence measure

        self.dce_tr = AverageMeter()
        self.dce_val = AverageMeter()

        self.snrout_tr = AverageMeter()
        self.snrout_val = AverageMeter()
        self.snrimpv_tr = AverageMeter()
        self.snrimpv_val = AverageMeter()

        if (config.linear_to_mel):
            self.mel_basis = librosa.filters.mel(self.config.fs,
                                                 self.config.nFFT,
                                                 self.config.nMel)
            self.melF_to_linearFs = get_linearF_from_melF(self.mel_basis)
            self.STFT_to_LMFB = STFT_to_LMFB(self.mel_basis,
                                             window_change=False)
            self.mag2mfb = linearmag2mel(self.mel_basis)

        mel_basis_20ms = librosa.filters.mel(
            self.config.fs, 320, self.config.nMel
        )  # mel_basis will be used only for 20ms window spectrogram
        self.STFT_to_LMFB_20ms = STFT_to_LMFB(mel_basis_20ms,
                                              win_size=self.config.nFFT)

        self.F = int(self.config.nFFT / 2 + 1)

        self.build_model()
        self.G.loss_stop = 100000
        #self.get_weight_statistic()

        if self.config.gpu >= 0:
            self.G.cuda()

        if len(self.config.load_path) > 0:
            self.load_model()

        if config.mode == 'train':
            self.logFile = open(self.log_dir + '/log.txt', 'w')

    def zero_grad_all(self):
        self.G.zero_grad()

    def build_model(self):
        self.G = LineartoMel_real(F=self.F,
                                  melF_to_linearFs=self.melF_to_linearFs,
                                  nCH=self.config.nCH,
                                  w=self.config.convW,
                                  H=self.config.nMap_per_F,
                                  L=self.config.L_CNN,
                                  non_linear=self.config.non_linear,
                                  BN=self.config.complex_BN)  # model currently in use
        G_name = 'LineartoMel_real'

        print('initialized enhancement model as ' + G_name)
        nParam = count_parameters(self.G)
        print('# trainable parameters = ' + str(nParam))

    def load_model(self):
        print("[*] Load models from {}...".format(self.config.load_path))
        postfix = '_valmin'
        paths = glob(
            os.path.join(self.config.load_path, 'G{}*.pth'.format(postfix)))
        paths.sort()

        if len(paths) == 0:
            print("[!] No checkpoint found in {}...".format(self.load_path))
            assert (0), 'checkpoint not avilable'

        idxes = [
            int(os.path.basename(path.split('.')[0].split('_')[-1]))
            for path in paths
        ]
        if self.config.start_iter <= 0:
            self.config.start_iter = max(idxes)
            if (self.config.start_iter <= 0):  # if still not positive, raise an error
                raise Exception(
                    "start_iter is still not positive --> probably trying to load an initial random model"
                )

        if self.config.gpu < 0:  #CPU
            map_location = lambda storage, loc: storage
        else:  # GPU
            map_location = None

        # Ver2
        print('Load models from ' + self.config.load_path + ', ITERATION = ' +
              str(self.config.start_iter))
        self.G.load_state_dict(
            torch.load('{}/G{}_{}.pth'.format(self.config.load_path, postfix,
                                              self.config.start_iter),
                       map_location=map_location))

        print("[*] Model loaded")

    def train(self):
        # Setting
        optimizer_g = torch.optim.Adam(self.G.parameters(),
                                       lr=self.config.lr,
                                       betas=(self.beta1, self.beta2),
                                       amsgrad=True)

        for iter in trange(self.config.start_iter, self.config.max_iter):
            # Train
            data_list = self.data_loader.next(cl_ny='ny', type='train')
            inputs, cleans, mask = data_list[0], data_list[1], data_list[
                2]  # cleans: NxFxT, mask: Nx1xT

            if (len(data_list) >= 9):
                mixture_magnitude = data_list[7]
                mixture_phsdiff = data_list[8]
                inputs_augmented = torch.cat(
                    (torch.log(1 + mixture_magnitude), mixture_phsdiff), dim=2)
                mfb = self.mag2mfb(mixture_magnitude)
                cleans = self.STFT_to_LMFB(cleans)

            if (self.config.linear_to_mel):
                inputs = [_get_variable(inputs_augmented), _get_variable(mfb)]
            else:
                inputs = _get_variable(inputs)
            cleans = _get_variable(cleans)
            mask = _get_variable(mask)

            # forward
            outputs = self.G(
                inputs
            )  # forward: input = [log(magnitude), phase difference] --> output = log-mel-filterbank features

            dce, nElement = self.diffLoss(
                outputs, cleans, mask)  # already normalized inside function
            if (self.config.loss_per_freq):
                if (iter + 1) % self.config.log_iter == 0:
                    for f in range(dce.size(0)):
                        str_loss = "[{}/{}] (train) DCE_{}: {:.7f}".format(
                            iter, self.config.max_iter, f, dce[f].sum().item())
                        self.logFile.write(str_loss + '\n')

                dce = dce.sum()  # sum up all the loss

            total_loss = dce

            # backward
            self.zero_grad_all()
            total_loss.backward()

            optimizer_g.step()

            # log
            #pdb.set_trace()
            if (iter + 1) % self.config.log_iter == 0:
                #pdb.set_trace()
                str_loss = "[{}/{}] (train) DCE: {:.7f}".format(
                    iter, self.config.max_iter, dce.item())
                print(str_loss)
                self.logFile.write(str_loss + '\n')

                SNRout = self.get_SNRout(outputs, cleans, mask)
                SNRout = SNRout.sum() / cleans.size(0)

                str_loss = "[{}/{}] (train) SNRout: {:.7f}".format(
                    iter, self.config.max_iter, SNRout.item())
                print(str_loss)
                self.logFile.write(str_loss + '\n')

                self.logFile.flush()

            if (iter + 1) % self.config.save_iter == 0:
                with torch.no_grad():
                    self.G.eval()
                    self.diffLoss.eval()
                    # Measure performance on training subset
                    self.dce_tr.reset()
                    self.snrout_tr.reset()
                    self.snrimpv_tr.reset()

                    for _ in trange(0, len(self.data_loader.trsub_dl)):
                        data_list = self.data_loader.next(cl_ny='ny',
                                                          type='trsub')
                        inputs, cleans, mask = data_list[0], data_list[
                            1], data_list[2]
                        if (len(data_list) >= 6):
                            targets, input_percentages, target_sizes = data_list[
                                3], data_list[4], data_list[5]
                            if (len(data_list) >= 7):
                                SNRin_1s = _get_variable(data_list[6])
                                if (len(data_list) >= 9):
                                    mixture_magnitude = data_list[7]
                                    mixture_phsdiff = data_list[8]
                                    inputs_augmented = torch.cat(
                                        (torch.log(1 + mixture_magnitude),
                                         mixture_phsdiff),
                                        dim=2)
                                    mfb = self.mag2mfb(mixture_magnitude)
                                    cleans = self.STFT_to_LMFB(cleans)

                        cleans, mask = _get_variable(cleans), _get_variable(
                            mask)

                        if (self.config.linear_to_mel):
                            inputs = [
                                _get_variable(inputs_augmented),
                                _get_variable(mfb)
                            ]
                        else:
                            inputs = _get_variable(inputs)

                        # Forward (of training subset)
                        outputs = self.G(inputs)

                        dce, nElement = self.diffLoss(
                            outputs, cleans,
                            mask)  # already normalized inside function
                        self.dce_tr.update(dce.item(), nElement)

                        SNRout = self.get_SNRout(outputs, cleans, mask)
                        SNRimprovement = SNRout - SNRin_1s
                        SNRout = SNRout.sum() / cleans.size(0)
                        SNRimprovement = SNRimprovement.sum() / cleans.size(0)

                        self.snrout_tr.update(SNRout.item(), cleans.size(0))
                        self.snrimpv_tr.update(SNRimprovement.item(),
                                               cleans.size(0))

                    str_loss = "[{}/{}] (training subset) DCE: {:.7f}".format(
                        iter, self.config.max_iter, self.dce_tr.avg)
                    print(str_loss)
                    self.logFile.write(str_loss + '\n')

                    str_loss = "[{}/{}] (training subset) SNRout: {:.7f}".format(
                        iter, self.config.max_iter, self.snrout_tr.avg)
                    print(str_loss)
                    self.logFile.write(str_loss + '\n')

                    str_loss = "[{}/{}] (training subset) SNRimprovement: {:.7f}".format(
                        iter, self.config.max_iter, self.snrimpv_tr.avg)
                    print(str_loss)
                    self.logFile.write(str_loss + '\n')

                    # Measure performance on validation data
                    self.dce_val.reset()
                    self.snrout_val.reset()
                    self.snrimpv_val.reset()

                    for _ in trange(0, len(self.data_loader.val_dl)):
                        data_list = self.data_loader.next(cl_ny='ny',
                                                          type='val')
                        inputs, cleans, mask = data_list[0], data_list[
                            1], data_list[2]
                        if (len(data_list) >= 6):
                            targets, input_percentages, target_sizes = data_list[
                                3], data_list[4], data_list[5]
                            if (len(data_list) >= 7):
                                SNRin_1s = _get_variable(data_list[6])
                                if (len(data_list) >= 9):
                                    mixture_magnitude = data_list[7]
                                    mixture_phsdiff = data_list[8]
                                    mfb = self.mag2mfb(mixture_magnitude)
                                    inputs_augmented = torch.cat(
                                        (torch.log(1 + mixture_magnitude),
                                         mixture_phsdiff),
                                        dim=2)
                                    cleans = self.STFT_to_LMFB(cleans)

                        cleans, mask = _get_variable(cleans), _get_variable(
                            mask)

                        if (self.config.linear_to_mel):
                            inputs = [
                                _get_variable(inputs_augmented),
                                _get_variable(mfb)
                            ]
                        else:
                            inputs = _get_variable(inputs)

                        # Forward (of validation)
                        outputs = self.G(inputs)

                        dce, nElement = self.diffLoss(
                            outputs, cleans,
                            mask)  # already normalized inside function

                        self.dce_val.update(dce.item(), nElement)

                        SNRout = self.get_SNRout(outputs, cleans, mask)
                        SNRimprovement = SNRout - SNRin_1s
                        SNRout = SNRout.sum() / cleans.size(0)
                        SNRimprovement = SNRimprovement.sum() / cleans.size(0)

                        self.snrout_val.update(SNRout.item(), cleans.size(0))
                        self.snrimpv_val.update(SNRimprovement.item(),
                                                cleans.size(0))

                    str_loss = "[{}/{}] (validation) DCE: {:.7f}".format(
                        iter, self.config.max_iter, self.dce_val.avg)
                    print(str_loss)
                    self.logFile.write(str_loss + '\n')

                    str_loss = "[{}/{}] (validation) SNRout: {:.7f}".format(
                        iter, self.config.max_iter, self.snrout_val.avg)
                    print(str_loss)
                    self.logFile.write(str_loss + '\n')

                    str_loss = "[{}/{}] (validation) SNRimprovement: {:.7f}".format(
                        iter, self.config.max_iter, self.snrimpv_val.avg)
                    print(str_loss)
                    self.logFile.write(str_loss + '\n')

                    self.G.train()  # end of validation
                    self.diffLoss.train()
                    self.logFile.flush()

                    # Save model
                    if (len(self.savename_G) > 0):  # do not remove here
                        if os.path.exists(self.savename_G):
                            os.remove(self.savename_G)  # remove previous model
                    self.savename_G = '{}/G_{}.pth'.format(
                        self.model_dir, iter)
                    torch.save(self.G.state_dict(), self.savename_G)

                    if (self.G.loss_stop > self.dce_val.avg):  # select on validation DCE (no WER is computed in this trainer)
                        self.G.loss_stop = self.dce_val.avg
                        savename_G_valmin_prev = '{}/G_valmin_{}.pth'.format(
                            self.model_dir, self.valmin_iter)
                        if os.path.exists(savename_G_valmin_prev):
                            os.remove(savename_G_valmin_prev
                                      )  # remove previous model

                        print('save model for this checkpoint')
                        savename_G_valmin = '{}/G_valmin_{}.pth'.format(
                            self.model_dir, iter)
                        copyfile(self.savename_G, savename_G_valmin)
                        self.valmin_iter = iter
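Both enhancement trainers in this collection (Exemplo 27 and 30) depend on the custom L1Loss_mask module, which is only referenced. Its contract, as used above, is: take (outputs, cleans, mask), ignore padded time frames, and return a loss that is already normalized together with the number of valid elements. A minimal sketch under those assumptions (not the project's actual implementation; the per-frequency loss_per_freq variant of Exemplo 30 is not covered here):

import torch
import torch.nn as nn

class L1Loss_mask(nn.Module):
    """Masked L1 loss: ignores padded frames, normalizes by the valid element count."""

    def forward(self, outputs, cleans, mask):
        # outputs, cleans: N x F x T; mask: N x 1 x T with 1 for valid frames, 0 for padding
        diff = torch.abs(outputs - cleans) * mask          # mask broadcasts over the F axis
        nElement = mask.sum() * outputs.size(1)            # valid frames times feature bins
        loss = diff.sum() / nElement.clamp(min=1)
        return loss, nElement.item()

The returned nElement is what the trainers pass as the weight to AverageMeter.update, so the running DCE stays a per-element average even across batches of different lengths.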