Exemplo n.º 1
0
def train_network(start_epoch, epochs, scheduler, model, train_loader,
                  val_loader, optimizer, criterion, device, dtype, batch_size,
                  log_interval, csv_logger, save_path, claimed_acc1,
                  claimed_acc5, best_test):
    """Train `model` over epochs [start_epoch, epochs], tracking the lowest
    validation loss and checkpointing whenever it improves.

    `claimed_acc1` / `claimed_acc5` are accepted for signature parity with
    the sibling variants but are unused here. The final best loss is written
    to `csv_logger`.
    """
    for epoch in trange(start_epoch, epochs + 1):
        # CyclicLR is stepped per batch inside train(); all other schedulers
        # step once per epoch here.
        if not isinstance(scheduler, CyclicLR):
            scheduler.step()
        train_loss = train(model, train_loader, epoch, optimizer, criterion,
                           device, dtype, batch_size, log_interval, scheduler)
        test_loss = test(model, val_loader, criterion, device, dtype)
        csv_logger.write({
            'epoch': epoch + 1,
            'val_loss': test_loss,
            'train_loss': train_loss
        })
        # Fix: update the running best BEFORE checkpointing. Previously the
        # checkpoint stored the stale pre-update value under 'best_prec1',
        # so resuming from the best checkpoint lost the actual best loss.
        # (Key name 'best_prec1' is kept — resume code reads that key,
        # even though the value here is a loss.)
        is_best = test_loss < best_test
        if is_best:
            best_test = test_loss
        save_checkpoint(
            {
                'epoch': epoch + 1,
                'state_dict': model.state_dict(),
                'best_prec1': best_test,
                'optimizer': optimizer.state_dict()
            },
            is_best,
            filepath=save_path)

    csv_logger.write_text('Best loss is {}'.format(best_test))
Exemplo n.º 2
0
def train_network(start_epoch, epochs, scheduler, model, train_loader, val_loader, test_loader, optimizer, criterion,
                  device, dtype,
                  batch_size, log_interval, csv_logger, save_path, best_val):
    """Train `model`, evaluating on both val and test loaders every epoch.

    Checkpoints to `save_path` whenever validation MAE improves; the final
    lowest MAE is written to `csv_logger`.
    """
    for epoch in trange(start_epoch, epochs + 1):
        # CyclicLR steps per batch inside train(); other schedulers step
        # once per epoch here.
        if not isinstance(scheduler, CyclicLR):
            scheduler.step()

        train_loss, train_mae = train(model, train_loader, epoch, optimizer, criterion, device,
                                      dtype, batch_size, log_interval, scheduler)
        val_loss, val_mae = test(model, val_loader, criterion, device, dtype)
        test_loss, test_mae = test(model, test_loader, criterion, device, dtype)

        metrics = {
            'epoch': epoch + 1,
            'test_mae': test_mae,
            'test_loss': test_loss,
            'val_mae': val_mae,
            'val_loss': val_loss,
            'train_mae': train_mae,
            'train_loss': train_loss,
        }
        csv_logger.write(metrics)

        is_best = val_mae < best_val
        checkpoint = {
            'epoch': epoch + 1,
            'state_dict': model.state_dict(),
            'best_prec1': best_val,
            'optimizer': optimizer.state_dict(),
        }
        save_checkpoint(checkpoint, is_best, filepath=save_path)

        csv_logger.plot_progress()

        if is_best:
            best_val = val_mae

    csv_logger.write_text('Lowest mae is {:.2f}'.format(best_val))
Exemplo n.º 3
0
def train_network(start_epoch, epochs, scheduler, model, train_loader,
                  val_loader, optimizer, criterion, device, dtype, batch_size,
                  log_interval, csv_logger, save_path, claimed_acc1,
                  claimed_acc5, best_test):
    """Train `model` over epochs [start_epoch, epochs], tracking best top-1
    validation accuracy and checkpointing whenever it improves.

    Per-epoch error/loss metrics are logged to `csv_logger` and echoed to
    stdout; `claimed_acc1`/`claimed_acc5` are passed through to the
    progress plot for reference lines.
    """
    for epoch in trange(start_epoch, epochs + 1):
        # CyclicLR is stepped per batch inside train(); all other schedulers
        # step once per epoch here.
        if not isinstance(scheduler, CyclicLR):
            scheduler.step()
        train_loss, train_accuracy1, train_accuracy5 = train(
            model, train_loader, epoch, optimizer, criterion, device, dtype,
            batch_size, log_interval, scheduler)
        test_loss, test_accuracy1, test_accuracy5 = test(
            model, val_loader, criterion, device, dtype)
        # Build the metrics dict once; it was previously duplicated verbatim
        # (once for csv logging, once as `temp_dict` for printing).
        metrics = {
            'epoch': epoch + 1,
            'val_error1': 1 - test_accuracy1,
            'val_error5': 1 - test_accuracy5,
            'val_loss': test_loss,
            'train_error1': 1 - train_accuracy1,
            'train_error5': 1 - train_accuracy5,
            'train_loss': train_loss
        }
        csv_logger.write(metrics)
        # Fix: update the running best BEFORE checkpointing so the saved
        # 'best_prec1' reflects this epoch's result; previously the stale
        # pre-update value was stored, corrupting best-tracking on resume.
        is_best = test_accuracy1 > best_test
        if is_best:
            best_test = test_accuracy1
        save_checkpoint(
            {
                'epoch': epoch + 1,
                'state_dict': model.state_dict(),
                'best_prec1': best_test,
                'optimizer': optimizer.state_dict()
            },
            is_best,
            filepath=save_path)

        csv_logger.plot_progress(claimed_acc1=claimed_acc1,
                                 claimed_acc5=claimed_acc5)

        print('Best accuracy is {:.2f}% top-1'.format(best_test * 100.))
        for key, value in metrics.items():
            print(key, ":", value)

    csv_logger.write_text('Best accuracy is {:.2f}% top-1'.format(best_test *
                                                                  100.))
Exemplo n.º 4
0
def train_network(start_epoch, epochs, scheduler, model, train_loader,
                  val_loader, adv_data, optimizer, criterion, device, dtype,
                  batch_size, log_interval, csv_logger, save_path,
                  claimed_acc1, claimed_acc5, best_test):
    """Train `model`, optionally running a second pass over adversarial data
    each epoch, checkpointing on each new best top-1 validation accuracy.

    `adv_data`, when not None, is an extra loader trained on after the
    clean pass. The scheduler steps once per epoch at the END of the epoch
    (unlike the sibling variants, which step at the start).
    """
    # Hoisted out of the per-epoch/per-layer loops: the project-local import
    # previously re-ran the import machinery once per module per epoch.
    from layers import NoisedConv2D, NoisedLinear

    for epoch in trange(start_epoch, epochs + 1):
        train_loss, train_accuracy1, train_accuracy5 = train(
            model, train_loader, epoch, optimizer, criterion, device, dtype,
            batch_size, log_interval)
        if adv_data is not None:
            # Adversarial pass; its metrics are computed but not logged —
            # presumably intentional, kept as in the original.
            traina_loss, traina_accuracy1, traina_accuracy5 = train(
                model, adv_data, epoch, optimizer, criterion, device, dtype,
                batch_size, log_interval)
        test_loss, test_accuracy1, test_accuracy5 = test(
            model, val_loader, criterion, device, dtype)
        csv_logger.write({
            'epoch': epoch + 1,
            'val_error1': 1 - test_accuracy1,
            'val_error5': 1 - test_accuracy5,
            'val_loss': test_loss,
            'train_error1': 1 - train_accuracy1,
            'train_error5': 1 - train_accuracy5,
            'train_loss': train_loss
        })
        save_checkpoint(
            {
                'epoch': epoch + 1,
                'state_dict': model.state_dict(),
                'best_prec1': best_test,
                'optimizer': optimizer.state_dict()
            },
            test_accuracy1 > best_test,
            filepath=save_path)

        csv_logger.plot_progress(claimed_acc1=claimed_acc1,
                                 claimed_acc5=claimed_acc5)

        if test_accuracy1 > best_test:
            best_test = test_accuracy1
        # Diagnostic dump of learned noise magnitudes for noised layers.
        for layer in model.modules():
            if isinstance(layer, (NoisedConv2D, NoisedLinear)):
                print("Mean of alphas is {}".format(torch.mean(layer.alpha)))
        scheduler.step()

    csv_logger.write_text('Best accuracy is {:.2f}% top-1'.format(best_test *
                                                                  100.))
Exemplo n.º 5
0
def adv_train_network_alpha(start_epoch, epochs, scheduler, model,
                            train_loader, val_loader, optimizer, criterion,
                            device, dtype, batch_size, log_interval,
                            csv_logger, save_path, claimed_acc1, claimed_acc5,
                            best_test, adv_method, eps, adv_w, normalize):
    """Adversarial training driven by a per-epoch alpha blending schedule.

    The schedule holds alpha at 1 for the first eighth of training, decays
    logarithmically to 1e-4 over the middle stretch, then stays at 0; the
    trailing zeros are padded by 20 entries so the epoch+1 lookahead passed
    to adv_train() never overruns the array.
    """
    alpha_sched = np.concatenate(
        (np.ones(epochs // 8), np.logspace(0, -4, epochs - 2 * (epochs // 8)),
         np.zeros(epochs // 8 + 20)))
    for epoch in trange(start_epoch, epochs + 1):
        alpha_now = alpha_sched[epoch]
        alpha_next = alpha_sched[epoch + 1]
        model.set_alpha(alpha_now)
        tqdm.write("alpha={}".format(alpha_now))

        train_loss, train_top1, train_top5 = adv_train(
            model, train_loader, epoch, optimizer, criterion, device, dtype,
            batch_size, log_interval, adv_method, eps, adv_w, normalize, 0.05,
            True, alpha_now, alpha_next)
        val_loss, val_top1, val_top5 = test(
            model, val_loader, criterion, device, dtype)

        csv_logger.write({'epoch': epoch + 1,
                          'val_error1': 1 - val_top1,
                          'val_error5': 1 - val_top5,
                          'val_loss': val_loss,
                          'train_error1': 1 - train_top1,
                          'train_error5': 1 - train_top5,
                          'train_loss': train_loss})

        is_best = val_top1 > best_test
        save_checkpoint({'epoch': epoch + 1,
                         'state_dict': model.state_dict(),
                         'best_prec1': best_test,
                         'optimizer': optimizer.state_dict()},
                        is_best,
                        filepath=save_path)

        csv_logger.plot_progress(claimed_acc1=claimed_acc1,
                                 claimed_acc5=claimed_acc5)

        if is_best:
            best_test = val_top1
        scheduler.step()

    csv_logger.write_text('Best accuracy is {:.2f}% top-1'.format(best_test *
                                                                  100.))
Exemplo n.º 6
0
def train_network(start_epoch, epochs, optim, model, train_loader, val_loader,
                  criterion, mixup, device, dtype, batch_size, log_interval,
                  csv_logger, save_path, claimed_acc1, claimed_acc5, best_test,
                  local_rank, child):
    """Distributed-aware training loop tracking best top-1 accuracy.

    Child workers iterate with a plain `range` (no progress bar); the
    parent uses `trange`. Cyclic/cosine schedulers step inside train();
    all others step once per epoch here. Checkpoints are tagged with
    `local_rank`.
    """
    epoch_range = range if child else trange
    for epoch in epoch_range(start_epoch, epochs + 1):
        if not isinstance(optim.scheduler, (CyclicLR, CosineLR)):
            optim.scheduler_step()

        train_loss, train_top1, train_top5 = train(
            model, train_loader, mixup, epoch, optim, criterion, device, dtype,
            batch_size, log_interval, child)
        val_loss, val_top1, val_top5 = test(
            model, val_loader, criterion, device, dtype, child)

        csv_logger.write({'epoch': epoch + 1,
                          'val_error1': 1 - val_top1,
                          'val_error5': 1 - val_top5,
                          'val_loss': val_loss,
                          'train_error1': 1 - train_top1,
                          'train_error5': 1 - train_top5,
                          'train_loss': train_loss})

        is_best = val_top1 > best_test
        save_checkpoint({'epoch': epoch + 1,
                         'state_dict': model.state_dict(),
                         'best_prec1': best_test,
                         'optimizer': optim.state_dict()},
                        is_best,
                        filepath=save_path,
                        local_rank=local_rank)

        csv_logger.plot_progress(claimed_acc1=claimed_acc1,
                                 claimed_acc5=claimed_acc5)

        if is_best:
            best_test = val_top1

    csv_logger.write_text('Best accuracy is {:.2f}% top-1'.format(best_test *
                                                                  100.))
Exemplo n.º 7
0
def train_network(start_epoch, epochs, scheduler, model, train_loader, val_loader, optimizer, criterion, device, dtype,
                  batch_size, log_interval, csv_logger, save_path, claimed_acc1, claimed_acc5, best_test):
    """Per-epoch training loop tracking the best top-1 validation accuracy.

    Steps non-cyclic schedulers once per epoch, logs error/loss metrics to
    `csv_logger`, and checkpoints to `save_path` whenever validation top-1
    accuracy exceeds the running best.
    """
    for epoch in trange(start_epoch, epochs + 1):
        # CyclicLR is stepped per batch inside train(), not per epoch.
        if not isinstance(scheduler, CyclicLR):
            scheduler.step()

        train_loss, train_top1, train_top5 = train(model, train_loader, epoch, optimizer, criterion, device,
                                                   dtype, batch_size, log_interval, scheduler)
        val_loss, val_top1, val_top5 = test(model, val_loader, criterion, device, dtype)

        csv_logger.write({
            'epoch': epoch + 1,
            'val_error1': 1 - val_top1,
            'val_error5': 1 - val_top5,
            'val_loss': val_loss,
            'train_error1': 1 - train_top1,
            'train_error5': 1 - train_top5,
            'train_loss': train_loss,
        })

        is_best = val_top1 > best_test
        checkpoint = {
            'epoch': epoch + 1,
            'state_dict': model.state_dict(),
            'best_prec1': best_test,
            'optimizer': optimizer.state_dict(),
        }
        save_checkpoint(checkpoint, is_best, filepath=save_path)

        csv_logger.plot_progress(claimed_acc1=claimed_acc1, claimed_acc5=claimed_acc5)

        if is_best:
            best_test = val_top1

    csv_logger.write_text('Best accuracy is {:.2f}% top-1'.format(best_test * 100.))
Exemplo n.º 8
0
def adv_train_network(start_epoch,
                      epochs,
                      scheduler,
                      model,
                      train_loader,
                      val_loader,
                      optimizer,
                      criterion,
                      device,
                      dtype,
                      batch_size,
                      log_interval,
                      csv_logger,
                      save_path,
                      claimed_acc1,
                      claimed_acc5,
                      best_test,
                      adv_method,
                      eps,
                      adv_w,
                      normalize,
                      args,
                      subts_args=None):
    """Adversarially train `model` with a fixed attack object, checkpointing
    on each new best top-1 validation accuracy.

    `adv_method` is an attack class instantiated once as
    `adv_method(model, criterion)`. `args` and `subts_args` are accepted
    but unused in this body (kept for interface compatibility). The
    scheduler steps once per epoch at the END of the epoch.
    """
    # Hoisted out of the per-epoch/per-layer loops: this project-local
    # import previously re-ran once per module per epoch.
    from layers import NoisedConv2D, NoisedLinear, NoisedConv2DColored

    att_object = adv_method(model, criterion)
    for epoch in trange(start_epoch, epochs + 1):

        train_loss, train_accuracy1, train_accuracy5 = adv_train(
            model, train_loader, epoch, optimizer, criterion, device, dtype,
            batch_size, log_interval, att_object, eps, adv_w, normalize, 0.05)
        test_loss, test_accuracy1, test_accuracy5 = test(
            model, val_loader, criterion, device, dtype)
        csv_logger.write({
            'epoch': epoch + 1,
            'val_error1': 1 - test_accuracy1,
            'val_error5': 1 - test_accuracy5,
            'val_loss': test_loss,
            'train_error1': 1 - train_accuracy1,
            'train_error5': 1 - train_accuracy5,
            'train_loss': train_loss
        })
        save_checkpoint(
            {
                'epoch': epoch + 1,
                'state_dict': model.state_dict(),
                'best_prec1': best_test,
                'optimizer': optimizer.state_dict()
            },
            test_accuracy1 > best_test,
            filepath=save_path)

        csv_logger.plot_progress(claimed_acc1=claimed_acc1,
                                 claimed_acc5=claimed_acc5)

        if test_accuracy1 > best_test:
            best_test = test_accuracy1
        # Diagnostic dump of learned noise parameters for noised layers.
        for layer in model.modules():
            if isinstance(layer, (NoisedConv2D, NoisedLinear)):
                tqdm.write("Mean of alphas is {}".format(
                    torch.mean(layer.alpha)))
            if isinstance(layer, NoisedConv2DColored):
                # Attribute availability varies between layer configurations,
                # so each report is best-effort. Was a bare `except:`, which
                # also swallowed KeyboardInterrupt/SystemExit — narrowed to
                # Exception.
                try:
                    tqdm.write("Mean of alphas_diag_w is {}+-{} ({}) ".format(
                        torch.mean(torch.abs(layer.alphad_w)),
                        torch.std(torch.abs(layer.alphad_w)),
                        torch.max(torch.abs(layer.alphad_w))))
                    # NOTE(review): std here uses the raw (signed) factor,
                    # unlike the abs() applied elsewhere — confirm intended.
                    tqdm.write(
                        "Mean of alphas_factor_w is {}+-{} ({}) ".format(
                            torch.mean(torch.abs(layer.alphaf_w)),
                            torch.std(layer.alphaf_w),
                            torch.max(torch.abs(layer.alphaf_w))))
                except Exception:
                    pass

                try:
                    tqdm.write("Mean of alphas_diag_a is {}+-{} ({})  ".format(
                        torch.mean(torch.abs(layer.alphad_i)),
                        torch.std(torch.abs(layer.alphad_i)),
                        torch.max(torch.abs(layer.alphad_i))))
                    tqdm.write(
                        "Mean of alphas_factor_a is {}+-{} ({}) ".format(
                            torch.mean(torch.abs(layer.alphaf_i)),
                            torch.std(layer.alphaf_i),
                            torch.max(torch.abs(layer.alphaf_i))))
                except Exception:
                    pass
        scheduler.step()

    csv_logger.write_text('Best accuracy is {:.2f}% top-1'.format(best_test *
                                                                  100.))
Exemplo n.º 9
0
def process(rank, world_size, train_pairs, test_pairs, resume):
    """DDP worker entry point: trains a TransformerSTT speech-to-text model
    with CTC loss on `train_pairs`, evaluating on `test_pairs` after each
    epoch (rank 0 only) and logging losses/WER/spectrograms to TensorBoard.

    Args:
        rank: this process's rank; also used directly as the CUDA device id.
        world_size: total number of DDP processes.
        train_pairs / test_pairs: datasets whose collated batches unpack to
            (mel, jamo codes, mel lengths, jamo lengths, transformer mask,
            speakers) — assumed from the unpack below; confirm against
            collate_function.
        resume: forwarded to resume_training() to restore model/optimizer/
            step and obtain the TensorBoard writer.
    """

    # NOTE(review): fixed rendezvous port — only one such process group can
    # run per machine at a time.
    os.environ['MASTER_ADDR'] = 'localhost'
    os.environ['MASTER_PORT'] = '12355'

    dist.init_process_group("nccl", rank=rank, world_size=world_size)

    device = rank

    # shuffle=False, so train_sampler.set_epoch() is not needed; if shuffling
    # is ever enabled, set_epoch(epoch) must be called each epoch.
    train_sampler = torch.utils.data.distributed.DistributedSampler(
        train_pairs, num_replicas=world_size, rank=rank, shuffle=False)

    # dataset_train = DataLoader(train_pairs, batch_size=BATCH_SIZE,
    #                            shuffle=True, num_workers=NUM_WORKERS,
    #                            collate_fn=collate_function,
    #                            pin_memory=True)

    dataset_train = DataLoader(train_pairs,
                               batch_size=BATCH_SIZE,
                               shuffle=False,
                               num_workers=NUM_WORKERS,
                               collate_fn=collate_function,
                               pin_memory=True,
                               sampler=train_sampler)

    # Test loader is NOT distributed: only rank 0 evaluates below.
    dataset_test = DataLoader(test_pairs,
                              batch_size=BATCH_SIZE,
                              shuffle=False,
                              num_workers=NUM_WORKERS,
                              collate_fn=collate_function,
                              drop_last=True,
                              pin_memory=True)

    model = TransformerSTT(**model_parameters)
    # model = nn.DataParallel(model)
    model = model.to(device)
    model = DDP(model, find_unused_parameters=True, device_ids=[rank])
    # print(str(model))
    learning_rate = LR
    optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)
    # zero_infinity avoids inf losses when a target is longer than the input.
    loss_criterion = nn.CTCLoss(zero_infinity=True)
    train_step = 0

    model, optimizer, train_step, writer = resume_training(
        resume, model, optimizer, train_step, rank)

    # Mixed-precision loss scaler, paired with autocast() below.
    scaler = GradScaler()

    loss_list = list()
    wer_list = list()

    for epoch in range(NUM_EPOCH):
        model.train()
        for data in tqdm(dataset_train):
            mel_tensor, jamo_code_tensor, mel_lengths, jamo_lengths, mel_transformer_mask, speakers = data

            # speaker_code = speaker_table.speaker_name_to_code(speakers)

            with autocast():
                output_tensor = model((
                    mel_tensor.to(device),
                    mel_transformer_mask.to(device),
                ))

                output_tensor = output_tensor.permute(
                    1, 0, 2)  # (N, S, E) => (T, N, C)

                # mel_lengths // 8: input lengths scaled to the model's
                # 8x temporal downsampling — presumably; confirm against
                # TransformerSTT's stride.
                loss = loss_criterion(output_tensor,
                                      jamo_code_tensor.to(device),
                                      (mel_lengths // 8).to(device),
                                      jamo_lengths.to(device))

            optimizer.zero_grad()
            # loss.backward()
            # optimizer.step()
            # AMP backward/step: unscale happens inside scaler.step().
            scaler.scale(loss).backward()
            scaler.step(optimizer)
            scaler.update()

            train_step += 1

            if rank == 0:

                decoded_input_text = KOREAN_TABLE.decode_jamo_code_tensor(
                    jamo_code_tensor)
                decoded_input_text = KOREAN_TABLE.decode_ctc_prediction(
                    decoded_input_text)
                decoded_output_text = KOREAN_TABLE.decode_jamo_prediction_tensor(
                    output_tensor)
                decoded_output_str = KOREAN_TABLE.decode_ctc_prediction(
                    decoded_output_text)

                wer = KOREAN_TABLE.caculate_wer(decoded_input_text,
                                                decoded_output_str)
                wer_list.append(wer)
                loss_list.append(loss.item())

                # Flush averaged metrics every LOGGING_STEPS steps, then
                # reset the accumulators.
                if len(loss_list) >= LOGGING_STEPS:
                    writer.add_scalar('ctc_loss/train', np.mean(loss_list),
                                      train_step)
                    decoded_pairs =  [f'** {in_text} \n\n -> {out_text} \n\n => {final_output} \n\n' \
                                    for (in_text, out_text, final_output) in zip(decoded_input_text, decoded_output_text, decoded_output_str)]
                    writer.add_text('text_result/train',
                                    '\n\n'.join(decoded_pairs), train_step)
                    writer.add_scalar('WER/train', np.mean(wer_list),
                                      train_step)
                    logging_image = mel_tensor_to_plt_image(
                        mel_tensor, decoded_input_text, train_step)
                    writer.add_image('input_spectrogram/train', logging_image,
                                     train_step)
                    print(f'Train Step {train_step}')
                    loss_list = list()
                    wer_list = list()

                if train_step % CHECKPOINT_STEPS == 0:
                    save_checkpoint(model, optimizer, train_step,
                                    writer.logdir, KEEP_LAST_ONLY)

            # break

        # End-of-epoch evaluation, rank 0 only.
        if rank == 0:

            loss_test_list = list()
            wer_test_list = list()

            model.eval()
            # NOTE(review): no torch.no_grad() around this eval loop — the
            # forward passes build autograd graphs needlessly; confirm and
            # wrap in no_grad/inference_mode.
            for data in tqdm(dataset_test):
                mel_tensor, jamo_code_tensor, mel_lengths, jamo_lengths, mel_transformer_mask, speakers = data

                with autocast():
                    output_tensor = model((
                        mel_tensor.to(device),
                        mel_transformer_mask.to(device),
                    ))

                    output_tensor = output_tensor.permute(
                        1, 0, 2)  # (N, S, E) => (T, N, C)

                    loss = loss_criterion(output_tensor,
                                          jamo_code_tensor.to(device),
                                          (mel_lengths // 8).to(device),
                                          jamo_lengths.to(device))

                loss_test_list.append(loss.item())

                decoded_input_text = KOREAN_TABLE.decode_jamo_code_tensor(
                    jamo_code_tensor)
                decoded_input_text = KOREAN_TABLE.decode_ctc_prediction(
                    decoded_input_text)
                decoded_output_text = KOREAN_TABLE.decode_jamo_prediction_tensor(
                    output_tensor)
                decoded_output_str = KOREAN_TABLE.decode_ctc_prediction(
                    decoded_output_text)
                wer = KOREAN_TABLE.caculate_wer(decoded_input_text,
                                                decoded_output_str)
                wer_test_list.append(wer)

            # Text/image logging below uses only the LAST test batch's
            # decodes; the scalars average over the whole test set.
            decoded_pairs =  [f'** {in_text} \n\n -> {out_text} \n\n => {final_output} \n\n' \
                        for (in_text, out_text, final_output) in zip(decoded_input_text, decoded_output_text, decoded_output_str)]
            writer.add_scalar('ctc_loss/test', np.mean(loss_test_list),
                              train_step)
            writer.add_scalar('WER/test', np.mean(wer_test_list), train_step)
            writer.add_text('text_result/test', '\n\n'.join(decoded_pairs),
                            train_step)
            logging_image = mel_tensor_to_plt_image(mel_tensor,
                                                    decoded_input_text,
                                                    train_step)
            writer.add_image('input_spectrogram/test', logging_image,
                             train_step)
    def train_network(self, resume=False):
        """Train self.model for self.opt.epochs epochs, logging to a CSV
        logger and a visdom-style plotter, checkpointing every epoch.

        Args:
            resume: when True, restore model/optimizer/best metric from the
                rank-local checkpoint under self.opt.resume and continue
                from its recorded epoch; when False, start fresh from
                self.opt.start_epoch with best accuracy 0.
        """
        csv_logger = CsvLogger(filepath=self.opt.save_model)
        if (not resume):
            best_test = 0

            start_epoch = self.opt.start_epoch
        else:
            checkpoint_path = os.path.join(
                self.opt.resume,
                'checkpoint{}.pth.tar'.format(self.opt.local_rank))
            # NOTE(review): csv_path is computed but never used in the
            # visible body — confirm whether the CSV logger should resume
            # from it.
            csv_path = os.path.join(
                self.opt.resume, 'results{}.csv'.format(self.opt.local_rank))
            print("=> loading checkpoint '{}'".format(checkpoint_path))
            checkpoint = torch.load(checkpoint_path,
                                    map_location=self.opt.device)
            start_epoch = checkpoint['epoch']
            # start_step is unused in the visible body — presumably for a
            # step-based scheduler; confirm downstream usage.
            start_step = len(self.train_loader) * start_epoch
            self.optim, self.mixup = self.init_optimizer_and_mixup(
                checkpoint['optimizer'])
            best_test = checkpoint['best_prec1']
            self.model.load_state_dict(checkpoint['state_dict'])
            print("=> loaded checkpoint '{}' (epoch {})".format(
                checkpoint_path, checkpoint['epoch']))

        for epoch in range(start_epoch, self.opt.epochs + 1):
            print("-------------------------", epoch,
                  "--------------------------------")

            train_loss, train_accuracy1, train_accuracy5, = run_train(
                self.model,
                self.train_loader,
                self.mixup,
                epoch,
                self.optim,
                self.criterion,
                self.device,
                self.opt._dtype,
                self.opt.train_batch_size,
                log_interval=2)
            test_loss, test_accuracy1, test_accuracy5 = run_test(
                self.model, self.val_loader, self.criterion, self.device,
                self.opt._dtype)

            self.optim.epoch_step()
            self.vis.plot_curves({'train_loss': train_loss},
                                 iters=epoch,
                                 title='train loss',
                                 xlabel='epoch',
                                 ylabel='train loss')
            self.vis.plot_curves({'train_acc': train_accuracy1},
                                 iters=epoch,
                                 title='train acc',
                                 xlabel='epoch',
                                 ylabel='train acc')
            self.vis.plot_curves({'val_loss': test_loss},
                                 iters=epoch,
                                 title='val loss',
                                 xlabel='epoch',
                                 ylabel='val loss')
            self.vis.plot_curves({'val_acc': test_accuracy1},
                                 iters=epoch,
                                 title='val acc',
                                 xlabel='epoch',
                                 ylabel='val acc')

            csv_logger.write({
                'epoch': epoch + 1,
                'val_error1': 1 - test_accuracy1,
                'val_error5': 1 - test_accuracy5,
                'val_loss': test_loss,
                'train_error1': 1 - train_accuracy1,
                'train_error5': 1 - train_accuracy5,
                'train_loss': train_loss
            })
            # NOTE(review): 'best_prec1' is written BEFORE best_test is
            # updated below, so the checkpoint records the previous best —
            # a resume would lose this epoch's improvement; confirm.
            save_checkpoint(
                {
                    'epoch': epoch + 1,
                    'state_dict': self.model.state_dict(),
                    'best_prec1': best_test,
                    'optimizer': self.optim.state_dict()
                },
                test_accuracy1 > best_test,
                filepath=self.opt.save_model,
                local_rank=self.opt.local_rank)
            # TODO: save on the end of the cycle

            mem = '%.3gG' % (torch.cuda.memory_cached() /
                             1E9 if torch.cuda.is_available() else 0)
            print("memory gpu use : ", mem)
            if test_accuracy1 > best_test:
                best_test = test_accuracy1