Example #1
0
def main():
    """Train the deblending/inpainting networks and periodically checkpoint.

    Builds the data loaders, networks and optimizers, loads a pre-generated
    mask from ``./Mask/mask.pth``, then calls ``train`` for every epoch in
    ``range(args['start_epoch'], args['num_epochs'])``. Every
    ``args['save_frequency']`` epochs it produces sample images with
    ``sample`` and hands them, together with the networks, to
    ``common.save_checkpoint`` under ``args['chkpt_root']``.
    """
    # Loaders for training, plus the matching validation splits.
    bg_train_loader, bg_val_loader, fg_train_loader, fg_val_loader = get_loaders()
    # Loader for testing.
    sirr_test_loader = test_loaders()
    # Extra loaders used only when sampling validation images.
    raw_val_loader, edgeR_val_loader, edgeB_val_loader = val_loaders()

    vgg16_net, deblend_net, Inp_net, Dis_net = build_networks()
    deblend_optimizer, Inp_optimizer, Dis_optimizer = get_optimizers(
        deblend_net, Inp_net, Dis_net)

    # The mask is pre-generated once and reused across runs. To regenerate:
    #   mask_patch = blend.build_masks()
    #   torch.save(mask_patch, './Mask/mask.pth')
    mask_loader = torch.load('./Mask/mask.pth')
    print('load mask success!')

    # exist_ok=True avoids the check-then-create race of exists() + makedirs().
    os.makedirs(args['chkpt_root'], exist_ok=True)

    # Train, and every `save_frequency` epochs sample images and checkpoint.
    for epoch in range(args['start_epoch'], args['num_epochs']):
        print('epoch: {}/{}'.format(epoch, args['num_epochs']))
        train(bg_train_loader, fg_train_loader, mask_loader, vgg16_net,
              deblend_net, Inp_net, Dis_net, deblend_optimizer, Inp_optimizer,
              Dis_optimizer, epoch)

        if epoch % args['save_frequency'] == 0:
            # Must stay in the same order as the samples returned by `sample`.
            image_names = ('x1', 'x2', 'x', 'y1', 'y2', 'y3')
            image_samples = sample(bg_val_loader, fg_val_loader,
                                   raw_val_loader, edgeR_val_loader,
                                   edgeB_val_loader, sirr_test_loader,
                                   mask_loader, deblend_net, Inp_net, epoch)
            common.save_checkpoint(args['chkpt_root'], epoch, image_samples,
                                   image_names, deblend_net, Inp_net, Dis_net)
Example #2
0
def main(args):
    """Starting point: train a new classifier head on a pre-trained model.

    Loads the datasets from ``args['DATA_DIRECTORY']`` and replaces the
    pre-trained model's classifier with one built by
    ``common.get_classifier``. It trains only that classifier's parameters
    with Adam, saves a checkpoint to ``<--save_dir>/checkpoint.pth``, and
    then evaluates on the test split.

    Args:
        args: mapping of command-line options with string values (presumably
            produced by docopt, given keys like ``'--arch'`` and
            ``'DATA_DIRECTORY'`` — confirm against the caller).

    Returns:
        0, intended as the process exit status.
    """
    logging.basicConfig(format='%(asctime)s -- %(message)s',
                        datefmt='%m/%d/%Y %H:%M:%S',
                        level=logging.INFO)

    model_arch = args['--arch']
    logging.info('Setting up %s model...', model_arch)
    image_datasets, dataloaders = common.load_datasets(args['DATA_DIRECTORY'])
    device = torch.device('cuda' if args['--gpu'] else 'cpu')

    # Option values arrive as strings; convert them once up front.
    hidden_units = int(args['--hidden_units'])
    dropout = float(args['--dropout'])
    epochs = int(args['--epochs'])
    learning_rate = float(args['--learning_rate'])

    nn_model = common.get_pre_trained_model(model_arch)
    logging.info('Original pre-trained classifier: %s', nn_model.classifier)
    # Infer the classifier input width from the model unless given explicitly.
    if args['--input_units'] is None:
        num_ftrs = get_classifier_input_size(nn_model)
    else:
        num_ftrs = int(args['--input_units'])
    num_classes = len(image_datasets['train'].classes)
    nn_model.classifier = common.get_classifier(num_ftrs, hidden_units,
                                                num_classes, dropout)
    nn_model.class_to_idx = image_datasets['train'].class_to_idx

    logging.info('Running training loop...')
    # Only the new classifier head is optimised.
    optimizer = optim.Adam(nn_model.classifier.parameters(), lr=learning_rate)
    train_nn_model(nn_model, optimizer, dataloaders['train'],
                   dataloaders['valid'], device, epochs)

    checkpoint = os.path.join(args['--save_dir'], 'checkpoint.pth')
    logging.info('Saving checkpoint in "%s"...', checkpoint)
    common.save_checkpoint(
        checkpoint, model_arch, nn_model, {
            'input_units': num_ftrs,
            'hidden_units': hidden_units,
            'output_units': num_classes,
            'dropout': dropout,
        })

    logging.info('Running model over test set...')
    test_model(nn_model, dataloaders['test'], device)

    logging.info('My job is done, going gently into that good night!')
    return 0
Example #3
0
    def validate(epoch):
        """Evaluate ``model`` on ``val_loader`` and report top-1/top-5 accuracy.

        Each process accumulates local accuracy meters. The meters are then
        combined with ``torch_dist_sum`` (presumably a cross-process sum —
        confirm). Only the process with ``args.gpu == 0`` prints the result.
        Unless ``args.eval`` is set, that process also appends the top-1
        accuracy to ``acclist_val``, updates ``best_pred`` when it improves,
        and writes a checkpoint tagged with ``epoch``.
        """
        global best_pred, acclist_train, acclist_val
        model.eval()
        meter1 = AverageMeter()
        meter5 = AverageMeter()

        for data, target in val_loader:
            data = data.cuda(args.gpu)
            target = target.cuda(args.gpu)
            with torch.no_grad():
                output = model(data)
                acc1, acc5 = accuracy(output, target, topk=(1, 5))
                batch_size = data.size(0)
                meter1.update(acc1[0], batch_size)
                meter5.update(acc5[0], batch_size)

        # Every process takes part in combining the per-process sums/counts.
        sum1, cnt1, sum5, cnt5 = torch_dist_sum(
            args.gpu, meter1.sum, meter1.count, meter5.sum, meter5.count)

        eval_only = args.eval
        if args.gpu != 0:
            # Only the GPU-0 process reports and checkpoints.
            return

        top1_acc = sum(sum1) / sum(cnt1)
        top5_acc = sum(sum5) / sum(cnt5)
        print('Validation: Top1: %.3f | Top5: %.3f'%(top1_acc, top5_acc))
        if eval_only:
            # Evaluation-only runs report but never record or checkpoint.
            return

        acclist_val += [top1_acc]
        is_best = bool(top1_acc > best_pred)
        if is_best:
            best_pred = top1_acc
        save_checkpoint({
            'epoch': epoch,
            'state_dict': model.module.state_dict(),
            'optimizer': optimizer.state_dict(),
            'best_pred': best_pred,
            'acclist_train':acclist_train,
            'acclist_val':acclist_val,
            }, args=args, is_best=is_best)
Example #4
0
def main():
    """Train a chain model with the Kaldi chain objective.

    The objective is ``KaldiChainObjfFunction``; higher is better. Each
    epoch is saved to ``<dir>/epoch-N.pt`` with an ``epoch-N-info`` file,
    and the best epoch is additionally saved to ``<dir>/best_model.pt``.
    Training can resume from ``args.checkpoint``. On Ctrl-C the current
    model is saved as ``epoch-N-interrupted.pt`` before exiting cleanly.
    Exits with status -1 if no GPU is available.
    """
    args = get_args()
    setup_logger('{}/log-train'.format(args.dir), args.log_level)
    logging.info(' '.join(sys.argv))

    if not torch.cuda.is_available():
        logging.error('No GPU detected!')
        sys.exit(-1)

    # WARNING(fangjun): we have to select GPU at the very
    # beginning; otherwise you will get trouble later
    kaldi.SelectGpuDevice(device_id=args.device_id)
    kaldi.CuDeviceAllowMultithreading()
    device = torch.device('cuda', args.device_id)

    den_fst = fst.StdVectorFst.Read(args.den_fst_filename)

    # TODO(fangjun): pass these options from commandline
    opts = chain.ChainTrainingOptions()
    opts.l2_regularize = 5e-4
    opts.leaky_hmm_coefficient = 0.1

    den_graph = chain.DenominatorGraph(fst=den_fst, num_pdfs=args.output_dim)

    model = get_chain_model(feat_dim=args.feat_dim,
                            output_dim=args.output_dim,
                            lda_mat_filename=args.lda_mat_filename,
                            hidden_dim=args.hidden_dim,
                            kernel_size_list=args.kernel_size_list,
                            stride_list=args.stride_list)

    start_epoch = 0
    num_epochs = args.num_epochs
    learning_rate = args.learning_rate
    # Sentinel "very bad" objective; the objective is maximised.
    best_objf = -100000

    if args.checkpoint:
        start_epoch, learning_rate, best_objf = load_checkpoint(
            args.checkpoint, model)
        logging.info(
            'loaded from checkpoint: start epoch {start_epoch}, '
            'learning rate {learning_rate}, best objf {best_objf}'.format(
                start_epoch=start_epoch,
                learning_rate=learning_rate,
                best_objf=best_objf))

    model.to(device)

    dataloader = get_egs_dataloader(egs_dir=args.cegs_dir,
                                    egs_left_context=args.egs_left_context,
                                    egs_right_context=args.egs_right_context)

    optimizer = optim.Adam(model.parameters(),
                           lr=learning_rate,
                           weight_decay=args.l2_regularize)

    # NOTE(review): the scheduler is created fresh on resume, so milestones
    # are counted from the resumed run's first epoch, starting from the
    # checkpointed learning rate — confirm this is intended.
    scheduler = MultiStepLR(optimizer, milestones=[1, 2, 3, 4, 5], gamma=0.5)
    criterion = KaldiChainObjfFunction.apply

    tf_writer = SummaryWriter(log_dir='{}/tensorboard'.format(args.dir))

    best_epoch = start_epoch
    best_model_path = os.path.join(args.dir, 'best_model.pt')
    best_epoch_info_filename = os.path.join(args.dir, 'best-epoch-info')
    # Bound before the loop so the KeyboardInterrupt handler can always use it.
    epoch = start_epoch
    try:
        for epoch in range(start_epoch, num_epochs):
            # get_last_lr() is the rate actually in use. get_lr() is meant for
            # the scheduler's own step(); called here it applies gamma a second
            # time at MultiStepLR milestones and so reports the wrong value.
            learning_rate = scheduler.get_last_lr()[0]
            logging.info('epoch {}, learning rate {}'.format(
                epoch, learning_rate))
            tf_writer.add_scalar('learning_rate', learning_rate, epoch)

            objf = train_one_epoch(dataloader=dataloader,
                                   model=model,
                                   device=device,
                                   optimizer=optimizer,
                                   criterion=criterion,
                                   current_epoch=epoch,
                                   opts=opts,
                                   den_graph=den_graph,
                                   tf_writer=tf_writer)
            scheduler.step()

            # load_checkpoint may hand back None for best_objf.
            if best_objf is None:
                best_objf = objf
                best_epoch = epoch

            # the higher, the better
            if objf > best_objf:
                best_objf = objf
                best_epoch = epoch
                save_checkpoint(filename=best_model_path,
                                model=model,
                                epoch=epoch,
                                learning_rate=learning_rate,
                                objf=objf)
                save_training_info(filename=best_epoch_info_filename,
                                   model_path=best_model_path,
                                   current_epoch=epoch,
                                   learning_rate=learning_rate,
                                   objf=best_objf,
                                   best_objf=best_objf,
                                   best_epoch=best_epoch)

            # we always save the model for every epoch
            model_path = os.path.join(args.dir, 'epoch-{}.pt'.format(epoch))
            save_checkpoint(filename=model_path,
                            model=model,
                            epoch=epoch,
                            learning_rate=learning_rate,
                            objf=objf)

            epoch_info_filename = os.path.join(args.dir,
                                               'epoch-{}-info'.format(epoch))
            save_training_info(filename=epoch_info_filename,
                               model_path=model_path,
                               current_epoch=epoch,
                               learning_rate=learning_rate,
                               objf=objf,
                               best_objf=best_objf,
                               best_epoch=best_epoch)

    except KeyboardInterrupt:
        # save the model when ctrl-c is pressed
        model_path = os.path.join(args.dir,
                                  'epoch-{}-interrupted.pt'.format(epoch))
        # use a very small objf for interrupted model
        objf = -100000
        save_checkpoint(model_path,
                        model=model,
                        epoch=epoch,
                        learning_rate=learning_rate,
                        objf=objf)

        epoch_info_filename = os.path.join(
            args.dir, 'epoch-{}-interrupted-info'.format(epoch))
        save_training_info(filename=epoch_info_filename,
                           model_path=model_path,
                           current_epoch=epoch,
                           learning_rate=learning_rate,
                           objf=objf,
                           best_objf=best_objf,
                           best_epoch=best_epoch)
    finally:
        # Flush TensorBoard events even if training fails with an error.
        tf_writer.close()

    logging.warning('Done')
Example #5
0
def main():
    """Run imbalanced semi-supervised training, optionally with DARP.

    Steps:
      * Derives the output directory name from ``args`` and creates it.
      * Builds imbalanced labeled/unlabeled splits for ``args.dataset``.
      * Trains a WRN (``models.WRN(2, num_class)``) together with an EMA copy
        for ``args.epochs`` epochs, resuming from ``args.resume`` if given.
      * Logs per-epoch losses and test metrics to ``<out>/log.txt`` and saves
        a checkpoint every epoch.
      * Finally prints the mean test accuracy ("bAcc") and GM over the last
        20 epochs.

    Raises:
        ValueError: if ``args.dataset`` is not ``'cifar10'``, ``'stl10'`` or
            ``'cifar100'``.
    """
    global best_acc

    args.out = args.dataset + '@N_' + str(args.num_max) + '_r_'
    if args.imb_ratio_l == args.imb_ratio_u:
        args.out += str(args.imb_ratio_l) + '_' + args.semi_method
    else:
        args.out += str(args.imb_ratio_l) + '_' + str(
            args.imb_ratio_u) + '_' + args.semi_method

    if args.darp:
        args.out += '_darp_alpha' + str(args.alpha) + '_iterT' + str(
            args.iter_T)

    if not os.path.isdir(args.out):
        mkdir_p(args.out)

    # Data: per-class sample counts for the labeled and unlabeled splits.
    N_SAMPLES_PER_CLASS = make_imb_data(args.num_max, args.num_class,
                                        args.imb_ratio_l)
    U_SAMPLES_PER_CLASS = make_imb_data(args.ratio * args.num_max,
                                        args.num_class, args.imb_ratio_u)
    N_SAMPLES_PER_CLASS_T = torch.Tensor(N_SAMPLES_PER_CLASS)

    print(args.out)

    # Root of the raw datasets. It can be overridden by an optional
    # `args.data_root` attribute; otherwise the original location is used.
    data_root = getattr(args, 'data_root', '/home/jaehyung/data')
    if args.dataset == 'cifar10':
        print('==> Preparing imbalanced CIFAR-10')
        train_labeled_set, train_unlabeled_set, test_set = get_cifar10(
            data_root, N_SAMPLES_PER_CLASS, U_SAMPLES_PER_CLASS, args.out)
    elif args.dataset == 'stl10':
        print('==> Preparing imbalanced STL-10')
        train_labeled_set, train_unlabeled_set, test_set = get_stl10(
            data_root, N_SAMPLES_PER_CLASS, args.out)
    elif args.dataset == 'cifar100':
        print('==> Preparing imbalanced CIFAR-100')
        train_labeled_set, train_unlabeled_set, test_set = get_cifar100(
            data_root, N_SAMPLES_PER_CLASS, U_SAMPLES_PER_CLASS, args.out)
    else:
        # Fail fast instead of an UnboundLocalError on the datasets below.
        raise ValueError('Unsupported dataset: {}'.format(args.dataset))

    labeled_trainloader = data.DataLoader(train_labeled_set,
                                          batch_size=args.batch_size,
                                          shuffle=True,
                                          num_workers=4,
                                          drop_last=True)
    unlabeled_trainloader = data.DataLoader(train_unlabeled_set,
                                            batch_size=args.batch_size,
                                            shuffle=True,
                                            num_workers=4,
                                            drop_last=True)
    test_loader = data.DataLoader(test_set,
                                  batch_size=args.batch_size,
                                  shuffle=False,
                                  num_workers=4)

    # Model
    print("==> creating WRN-28-2")

    def create_model(ema=False):
        """Build a WRN on the GPU; EMA copies get detached parameters."""
        model = models.WRN(2, args.num_class)
        model = model.cuda()

        if ema:
            # EMA weights are updated by WeightEMA, not by backprop.
            for param in model.parameters():
                param.detach_()

        return model

    model = create_model()
    ema_model = create_model(ema=True)

    cudnn.benchmark = True
    print('    Total params: %.2fM' %
          (sum(p.numel() for p in model.parameters()) / 1000000.0))

    train_criterion = SemiLoss()
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.Adam(model.parameters(), lr=args.lr)
    ema_optimizer = WeightEMA(model,
                              ema_model,
                              lr=args.lr,
                              alpha=args.ema_decay)
    start_epoch = 0

    # Resume
    title = 'Imbalanced' + '-' + args.dataset + '-' + args.semi_method
    if args.resume:
        # Load checkpoint.
        print('==> Resuming from checkpoint..')
        assert os.path.isfile(
            args.resume), 'Error: no checkpoint directory found!'
        args.out = os.path.dirname(args.resume)
        checkpoint = torch.load(args.resume)
        start_epoch = checkpoint['epoch']
        model.load_state_dict(checkpoint['state_dict'])
        ema_model.load_state_dict(checkpoint['ema_state_dict'])
        optimizer.load_state_dict(checkpoint['optimizer'])
        logger = Logger(os.path.join(args.out, 'log.txt'),
                        title=title,
                        resume=True)
    else:
        logger = Logger(os.path.join(args.out, 'log.txt'), title=title)
        logger.set_names([
            'Train Loss', 'Train Loss X', 'Train Loss U', 'Test Loss',
            'Test Acc.', 'Test GM.'
        ])

    test_accs = []
    test_gms = []

    # Default values for MixMatch and DARP
    emp_distb_u = torch.ones(args.num_class) / args.num_class
    pseudo_orig = torch.ones(len(train_unlabeled_set.data),
                             args.num_class) / args.num_class
    pseudo_refine = torch.ones(len(train_unlabeled_set.data),
                               args.num_class) / args.num_class

    # The estimated unlabeled distribution does not change between epochs,
    # so the estimate file is read once instead of on every epoch.
    if args.est:
        if args.dataset == 'cifar10':
            est_name = './estimation/cifar10@N_1500_r_{}_{}_estim.npy'.format(
                args.imb_ratio_l, args.imb_ratio_u)
        else:
            # NOTE(review): every non-CIFAR-10 dataset (cifar100 included)
            # falls back to the STL-10 estimate file — confirm intended.
            est_name = './estimation/stl10@N_450_r_{}_estim.npy'.format(
                args.imb_ratio_l)
        est_disb = np.load(est_name)

    # Main function
    for epoch in range(start_epoch, args.epochs):
        print('\nEpoch: [%d | %d] LR: %f' %
              (epoch + 1, args.epochs, state['lr']))

        # Recomputed each epoch so `trains` always receives a fresh tensor.
        if args.est:
            # Use the estimated distribution of unlabeled data
            target_disb = len(train_unlabeled_set.data) * torch.Tensor(
                est_disb) / np.sum(est_disb)
        else:
            # Use the inferred distribution with labeled data
            target_disb = N_SAMPLES_PER_CLASS_T * len(
                train_unlabeled_set.data) / sum(N_SAMPLES_PER_CLASS)

        train_loss, train_loss_x, train_loss_u, emp_distb_u, pseudo_orig, pseudo_refine = trains(
            args, labeled_trainloader, unlabeled_trainloader, model, optimizer,
            ema_optimizer, train_criterion, epoch, use_cuda, target_disb,
            emp_distb_u, pseudo_orig, pseudo_refine)

        # Evaluation part (on the EMA model)
        test_loss, test_acc, test_cls, test_gm = validate(
            test_loader,
            ema_model,
            criterion,
            use_cuda,
            mode='Test Stats',
            num_class=args.num_class)

        # Append logger file
        logger.append([
            train_loss, train_loss_x, train_loss_u, test_loss, test_acc,
            test_gm
        ])

        # Save models
        save_checkpoint(
            {
                'epoch': epoch + 1,
                'state_dict': model.state_dict(),
                'ema_state_dict': ema_model.state_dict(),
                'optimizer': optimizer.state_dict(),
            }, epoch + 1, args.out)
        test_accs.append(test_acc)
        test_gms.append(test_gm)

    logger.close()

    # Print the final results (averaged over the last 20 epochs)
    print('Mean bAcc:')
    print(np.mean(test_accs[-20:]))

    print('Mean GM:')
    print(np.mean(test_gms[-20:]))

    print('Name of saved folder:')
    print(args.out)
Example #6
0
def main_worker(gpu, ngpus_per_node, args):
    """Per-GPU entry point for distributed training or evaluation.

    Setup:
      * Joins the process group and builds the train/val loaders with
        ``DistributedSampler``.
      * Constructs ``models.__dict__[args.model]`` wrapped in
        ``DistributedDataParallel``, plus an SGD optimizer and an
        ``LR_Scheduler``.
      * Optionally resumes from ``args.resume``.

    Then it does exactly one of:
      * exports the weights (``args.export``);
      * runs a single validation pass (``args.eval``);
      * trains for ``args.epochs`` epochs, validating every 10 epochs and
        saving a final checkpoint from the GPU-0 process.

    Args:
        gpu: local GPU index for this process.
        ngpus_per_node: number of processes (GPUs) per node.
        args: parsed options. ``args.gpu``, ``args.rank`` and
            ``args.start_epoch`` are updated in place.

    Raises:
        RuntimeError: if ``args.resume`` is set but is not a file.
    """
    # These module-level globals are read/updated here and in train/validate.
    global best_pred, acclist_train, acclist_val

    args.gpu = gpu
    # On entry args.rank presumably holds the node rank; make it global.
    args.rank = args.rank * ngpus_per_node + gpu
    print('rank: {} / {}'.format(args.rank, args.world_size))
    dist.init_process_group(backend=args.dist_backend,
                            init_method=args.dist_url,
                            world_size=args.world_size,
                            rank=args.rank)
    torch.cuda.set_device(args.gpu)

    if args.gpu == 0:
        print(args)
    torch.manual_seed(args.seed)
    torch.cuda.manual_seed(args.seed)
    transform_train, transform_val = get_transform(
            args.dataset, args.base_size, args.crop_size, args.rand_aug)
    trainset = get_dataset(args.dataset, root=os.path.expanduser('~/.torch/data'),
                           transform=transform_train, train=True, download=True)
    valset = get_dataset(args.dataset, root=os.path.expanduser('~/.torch/data'),
                         transform=transform_val, train=False, download=True)
    # The sampler does the shuffling, hence shuffle=False on the loader.
    train_sampler = torch.utils.data.distributed.DistributedSampler(trainset)
    train_loader = torch.utils.data.DataLoader(
        trainset, batch_size=args.batch_size, shuffle=False,
        num_workers=args.workers, pin_memory=True,
        sampler=train_sampler)

    val_sampler = torch.utils.data.distributed.DistributedSampler(valset, shuffle=False)
    val_loader = torch.utils.data.DataLoader(
        valset, batch_size=args.test_batch_size, shuffle=False,
        num_workers=args.workers, pin_memory=True,
        sampler=val_sampler)

    model_kwargs = {}
    if args.pretrained:
        model_kwargs['pretrained'] = True

    if args.final_drop > 0.0:
        model_kwargs['final_drop'] = args.final_drop

    if args.dropblock_prob > 0.0:
        model_kwargs['dropblock_prob'] = args.dropblock_prob

    if args.last_gamma:
        model_kwargs['last_gamma'] = True

    if args.rectify:
        model_kwargs['rectified_conv'] = True
        model_kwargs['rectify_avg'] = args.rectify_avg
    model = models.__dict__[args.model](**model_kwargs)

    if args.dropblock_prob > 0.0:
        from functools import partial
        from models import reset_dropblock
        nr_iters = (args.epochs - args.warmup_epochs) * len(train_loader)
        apply_drop_prob = partial(reset_dropblock, args.warmup_epochs*len(train_loader),
                                  nr_iters, 0.0, args.dropblock_prob)
        model.apply(apply_drop_prob)
    if args.gpu == 0:
        print(model)

    if args.mixup > 0:
        train_loader = MixUpWrapper(args.mixup, 1000, train_loader, args.gpu)
        criterion = NLLMultiLabelSmooth(args.label_smoothing)
    elif args.label_smoothing > 0.0:
        criterion = LabelSmoothing(args.label_smoothing)
    else:
        criterion = nn.CrossEntropyLoss()

    model.cuda(args.gpu)
    criterion.cuda(args.gpu)
    model = DistributedDataParallel(model, device_ids=[args.gpu])

    if args.no_bn_wd:
        param_dict = dict(model.named_parameters())
        # Name heuristic: parameters whose name contains 'bn' or 'bias' are
        # excluded from weight decay.
        bn_params = [v for n, v in param_dict.items() if ('bn' in n or 'bias' in n)]
        rest_params = [v for n, v in param_dict.items() if not ('bn' in n or 'bias' in n)]
        if args.gpu == 0:
            print(" Weight decay NOT applied to BN parameters ")
            print(f'len(parameters): {len(list(model.parameters()))} = {len(bn_params)} + {len(rest_params)}')
        optimizer = torch.optim.SGD([{'params': bn_params, 'weight_decay': 0 },
                                     {'params': rest_params, 'weight_decay': args.weight_decay}],
                                    lr=args.lr,
                                    momentum=args.momentum,
                                    weight_decay=args.weight_decay)
    else:
        optimizer = torch.optim.SGD(model.parameters(),
                                    lr=args.lr,
                                    momentum=args.momentum,
                                    weight_decay=args.weight_decay)
    # check point
    if args.resume is not None:
        if not os.path.isfile(args.resume):
            raise RuntimeError("=> no resume checkpoint found at '{}'".format(args.resume))
        if args.gpu == 0:
            print("=> loading checkpoint '{}'".format(args.resume))
        # Without map_location, torch.load restores tensors onto the device
        # they were saved from (here the GPU-0 process), so every process
        # would allocate the checkpoint on that GPU. Load onto our own GPU.
        checkpoint = torch.load(args.resume,
                                map_location=torch.device('cuda', args.gpu))
        args.start_epoch = checkpoint['epoch'] + 1 if args.start_epoch == 0 else args.start_epoch
        best_pred = checkpoint['best_pred']
        acclist_train = checkpoint['acclist_train']
        acclist_val = checkpoint['acclist_val']
        model.module.load_state_dict(checkpoint['state_dict'])
        optimizer.load_state_dict(checkpoint['optimizer'])
        if args.gpu == 0:
            print("=> loaded checkpoint '{}' (epoch {})"
                  .format(args.resume, checkpoint['epoch']))
    scheduler = LR_Scheduler(args.lr_scheduler,
                             base_lr=args.lr,
                             num_epochs=args.epochs,
                             iters_per_epoch=len(train_loader),
                             warmup_epochs=args.warmup_epochs)

    def train(epoch):
        """Run one training epoch and append ``top1.avg`` to ``acclist_train``.

        The scheduler is stepped on every batch. Top-1 accuracy is only
        tracked when mixup is disabled.
        """
        global best_pred, acclist_train
        train_sampler.set_epoch(epoch)
        model.train()
        losses = AverageMeter()
        top1 = AverageMeter()
        for batch_idx, (data, target) in enumerate(train_loader):
            scheduler(optimizer, batch_idx, epoch, best_pred)
            if not args.mixup:
                # Only done here when mixup is off; MixUpWrapper is given
                # args.gpu and presumably handles device placement itself.
                data, target = data.cuda(args.gpu), target.cuda(args.gpu)
            optimizer.zero_grad()
            output = model(data)
            loss = criterion(output, target)
            loss.backward()
            optimizer.step()

            if not args.mixup:
                acc1 = accuracy(output, target, topk=(1,))
                top1.update(acc1[0], data.size(0))

            losses.update(loss.item(), data.size(0))
            if batch_idx % 100 == 0 and args.gpu == 0:
                if args.mixup:
                    print('Batch: %d| Loss: %.3f'%(batch_idx, losses.avg))
                else:
                    print('Batch: %d| Loss: %.3f | Top1: %.3f'%(batch_idx, losses.avg, top1.avg))

        acclist_train += [top1.avg]

    def validate(epoch):
        """Evaluate on ``val_loader``; on GPU 0, report and maybe checkpoint.

        The per-process meters are combined with ``torch_dist_sum``, which
        every process calls. Unless ``args.eval`` is set, the GPU-0 process
        appends top-1 accuracy to ``acclist_val``, updates ``best_pred`` and
        saves a checkpoint.
        """
        global best_pred, acclist_train, acclist_val
        model.eval()
        top1 = AverageMeter()
        top5 = AverageMeter()
        for data, target in val_loader:
            data, target = data.cuda(args.gpu), target.cuda(args.gpu)
            with torch.no_grad():
                output = model(data)
                acc1, acc5 = accuracy(output, target, topk=(1, 5))
                top1.update(acc1[0], data.size(0))
                top5.update(acc5[0], data.size(0))

        # sum all
        sum1, cnt1, sum5, cnt5 = torch_dist_sum(args.gpu, top1.sum, top1.count, top5.sum, top5.count)

        eval_only = args.eval
        if args.gpu != 0:
            return

        top1_acc = sum(sum1) / sum(cnt1)
        top5_acc = sum(sum5) / sum(cnt5)
        print('Validation: Top1: %.3f | Top5: %.3f'%(top1_acc, top5_acc))
        if eval_only:
            # Evaluation-only runs report but never record or checkpoint.
            return

        # save checkpoint
        acclist_val += [top1_acc]
        is_best = bool(top1_acc > best_pred)
        if is_best:
            best_pred = top1_acc
        save_checkpoint({
            'epoch': epoch,
            'state_dict': model.module.state_dict(),
            'optimizer': optimizer.state_dict(),
            'best_pred': best_pred,
            'acclist_train':acclist_train,
            'acclist_val':acclist_val,
            }, args=args, is_best=is_best)

    if args.export:
        if args.gpu == 0:
            torch.save(model.module.state_dict(), args.export + '.pth')
        return

    if args.eval:
        validate(args.start_epoch)
        return

    for epoch in range(args.start_epoch, args.epochs):
        tic = time.time()
        train(epoch)
        if epoch % 10 == 0:  # or epoch == args.epochs-1:
            validate(epoch)
        elapsed = time.time() - tic
        if args.gpu == 0:
            print(f'Epoch: {epoch}, Time cost: {elapsed}')

    if args.gpu == 0:
        save_checkpoint({
            'epoch': args.epochs-1,
            'state_dict': model.module.state_dict(),
            'optimizer': optimizer.state_dict(),
            'best_pred': best_pred,
            'acclist_train':acclist_train,
            'acclist_val':acclist_val,
            }, args=args, is_best=False)
Example #7
0
def main():
    """Train a Wav2Letter model against a k2 decoding graph.

    Builds the graph from the lexicon ``L`` and grammar ``G`` in
    ``data/lang_nosp``. It then trains on the ``train-clean-100`` cuts
    (validating on ``dev-clean``) for 10 epochs, multiplying the learning
    rate by 0.4 each epoch. Every epoch's model is saved under ``exp/``, and
    the lowest-objf epoch is also saved as ``exp/best_model.pt``. Exits with
    status -1 if no GPU is available.
    """
    exp_dir = 'exp'
    setup_logger('{}/log/log-train'.format(exp_dir))

    # Fail fast, before building the graph and loading the datasets.
    if not torch.cuda.is_available():
        logging.error('No GPU detected!')
        sys.exit(-1)

    # load L, G, symbol_table
    lang_dir = 'data/lang_nosp'
    with open(lang_dir + '/L.fst.txt', encoding='utf-8') as f:
        L = k2.Fsa.from_openfst(f.read(), acceptor=False)

    with open(lang_dir + '/G.fsa.txt', encoding='utf-8') as f:
        G = k2.Fsa.from_openfst(f.read(), acceptor=True)

    with open(lang_dir + '/words.txt', encoding='utf-8') as f:
        symbol_table = k2.SymbolTable.from_str(f.read())

    # Invert L, arc-sort both FSAs, and intersect them into the graph.
    L = k2.arc_sort(L.invert_())
    G = k2.arc_sort(G)
    graph = k2.intersect(L, G)
    graph = k2.arc_sort(graph)

    # load dataset
    feature_dir = 'exp/data1'
    cuts_train = CutSet.from_json(feature_dir +
                                  '/cuts_train-clean-100.json.gz')

    cuts_dev = CutSet.from_json(feature_dir + '/cuts_dev-clean.json.gz')

    train = K2SpeechRecognitionIterableDataset(cuts_train, shuffle=True)
    validate = K2SpeechRecognitionIterableDataset(cuts_dev, shuffle=False)
    # batch_size=None: batching is left to the dataset itself.
    train_dl = torch.utils.data.DataLoader(train,
                                           batch_size=None,
                                           num_workers=1)
    valid_dl = torch.utils.data.DataLoader(validate,
                                           batch_size=None,
                                           num_workers=1)

    device_id = 0
    device = torch.device('cuda', device_id)
    model = Wav2Letter(num_classes=364, input_type='mfcc', num_features=40)
    model.to(device)

    learning_rate = 0.001
    start_epoch = 0
    num_epochs = 10
    # Lower objf is better; +inf guarantees the first epoch is recorded.
    best_objf = float('inf')
    best_epoch = start_epoch
    best_model_path = os.path.join(exp_dir, 'best_model.pt')
    best_epoch_info_filename = os.path.join(exp_dir, 'best-epoch-info')

    optimizer = optim.Adam(model.parameters(),
                           lr=learning_rate,
                           weight_decay=5e-4)

    for epoch in range(start_epoch, num_epochs):
        # Exponential decay: lr * 0.4 ** epoch.
        curr_learning_rate = learning_rate * pow(0.4, epoch)
        for param_group in optimizer.param_groups:
            param_group['lr'] = curr_learning_rate

        logging.info('epoch {}, learning rate {}'.format(
            epoch, curr_learning_rate))
        objf = train_one_epoch(dataloader=train_dl,
                               valid_dataloader=valid_dl,
                               model=model,
                               device=device,
                               graph=graph,
                               symbols=symbol_table,
                               optimizer=optimizer,
                               current_epoch=epoch,
                               num_epochs=num_epochs)
        if objf < best_objf:
            best_objf = objf
            best_epoch = epoch
            save_checkpoint(filename=best_model_path,
                            model=model,
                            epoch=epoch,
                            learning_rate=curr_learning_rate,
                            objf=objf)
            save_training_info(filename=best_epoch_info_filename,
                               model_path=best_model_path,
                               current_epoch=epoch,
                               learning_rate=curr_learning_rate,
                               objf=best_objf,
                               best_objf=best_objf,
                               best_epoch=best_epoch)

        # we always save the model for every epoch
        model_path = os.path.join(exp_dir, 'epoch-{}.pt'.format(epoch))
        save_checkpoint(filename=model_path,
                        model=model,
                        epoch=epoch,
                        learning_rate=curr_learning_rate,
                        objf=objf)
        epoch_info_filename = os.path.join(exp_dir, 'epoch-{}-info'.format(epoch))
        save_training_info(filename=epoch_info_filename,
                           model_path=model_path,
                           current_epoch=epoch,
                           learning_rate=curr_learning_rate,
                           objf=objf,
                           best_objf=best_objf,
                           best_epoch=best_epoch)

    logging.warning('Done')
Example #8
0
def train(args):
    """Train the sequence classifier and validate it after every epoch.

    Builds the model, optimizer, criterion and data loaders via the project
    helpers, then for ``args.epoch`` epochs runs one pass over the training
    loader followed by one pass over the validation loader. Per-epoch average
    losses and accuracies are logged to ``<args.exp_name>/train.log`` and
    written as scalars to a TensorBoard summary under
    ``<args.exp_name>/summary``; ``save_checkpoint`` is called at the end of
    every epoch.

    Args:
        args: parsed options. This function reads ``exp_name``, ``epoch`` and
            ``batch_size`` directly and forwards ``args`` to the build/loader
            helpers.
    """
    logging.basicConfig(level=logging.INFO,
                        filename=args.exp_name + '/train.log',
                        filemode='w')
    # TensorBoard writer; closed after the epoch loop so events are flushed.
    writer = SummaryWriter(log_dir=args.exp_name + '/summary')

    # build the model, the optimizer and the criterion
    model = build_model(args)
    optimizer = build_optim(args, model)
    criterion = build_criterion()

    # load data and wrap it into train / validation loaders
    all_input_ids, all_labels = load_data()
    train_dataset, val_dataset = create_dataset(all_input_ids, all_labels)
    train_loader, test_loader = create_dataloader(args, train_dataset,
                                                  val_dataset)

    for epoch in range(args.epoch):
        total_loss, total_val_loss = 0.0, 0.0
        total_train_accuracy, total_eval_accuracy = 0.0, 0.0

        # ---- training pass ----
        h = model.init_hidden(args.batch_size, USE_CUDA)
        model.train()
        for inputs, labels in tqdm(train_loader):
            inputs, labels = inputs.to(device), labels.to(device)
            # Strip autograd history from the hidden state via .data. Note that
            # model() returns only the output here, so h itself is never
            # replaced by a new hidden state between batches.
            h = tuple(each.data for each in h)
            model.zero_grad()
            output = model(inputs, h)
            loss = criterion(output, labels.long())
            total_loss += loss.item()
            loss.backward()
            optimizer.step()

            total_train_accuracy += flat_accuracy(
                output.detach().cpu().numpy(), labels.cpu().numpy())

        # ---- validation pass ----
        model.eval()
        with torch.no_grad():
            # Same arguments as the training-time call (USE_CUDA was missing
            # here); presumably USE_CUDA controls where the hidden state is
            # allocated — TODO confirm against model.init_hidden.
            val_h = model.init_hidden(args.batch_size, USE_CUDA)
            for inputs, labels in tqdm(test_loader):
                val_h = tuple(each.data for each in val_h)
                inputs, labels = inputs.to(device), labels.to(device)

                output = model(inputs, val_h)
                # .item() accumulates a plain float, matching the train loss.
                total_val_loss += criterion(output, labels.long()).item()
                total_eval_accuracy += flat_accuracy(
                    output.cpu().numpy(), labels.cpu().numpy())

        avg_train_loss = total_loss / len(train_loader)
        avg_val_loss = total_val_loss / len(test_loader)
        avg_val_accuracy = total_eval_accuracy / len(test_loader)
        avg_train_accuracy = total_train_accuracy / len(train_loader)

        # Separators added: the fields previously ran together in one line.
        logging.info(f'Epoch     : {epoch+1}, '
                     f'Train loss     : {avg_train_loss}, '
                     f'Train Accuracy: {avg_train_accuracy:.2f}, '
                     f'Validation loss: {avg_val_loss}, '
                     f'Validation Accuracy: {avg_val_accuracy:.2f}')
        writer.add_scalar('Loss/train', avg_train_loss, epoch)
        writer.add_scalar('Loss/val', avg_val_loss, epoch)
        writer.add_scalar('Accuracy/train', avg_train_accuracy, epoch)
        writer.add_scalar('Accuracy/val', avg_val_accuracy, epoch)

        # NOTE(review): avg_val_loss is now a Python float rather than a 0-d
        # tensor; confirm save_checkpoint does not call tensor methods on it.
        save_checkpoint(args, model, epoch, avg_val_accuracy, avg_val_loss)

    writer.close()
Example #9
0
def main():
    """Train a Wav2Letter acoustic model with k2, checkpointing every epoch.

    Loads the word symbol table and the lexicon FST from ``data/lang_nosp``,
    builds lhotse-backed train/dev data loaders from ``exp/data1``, and trains
    for ``num_epochs`` epochs with Adam, multiplying the learning rate by 0.4
    each epoch. Every epoch writes ``exp/epoch-<n>.pt`` plus an info file;
    whenever the objective returned by ``train_one_epoch`` improves, the model
    is also written to ``exp/best_model.pt`` with ``exp/best-epoch-info``.
    Exits with status -1 if no CUDA device is available.
    """
    # load L, G, symbol_table
    lang_dir = 'data/lang_nosp'
    with open(lang_dir + '/words.txt') as f:
        symbol_table = k2.SymbolTable.from_str(f.read())

    # NOTE: an LG graph (L intersected with G) used to be built here, but G
    # carries disambiguation symbols that L.fst does not support, so only the
    # inverted, arc-sorted lexicon L is handed to train_one_epoch.
    print("Loading L.fst.txt")
    with open(lang_dir + '/L.fst.txt') as f:
        L = k2.Fsa.from_openfst(f.read(), acceptor=False)
    L = k2.arc_sort(L.invert_())

    # load dataset
    feature_dir = 'exp/data1'
    print("About to get train cuts")
    # NOTE(review): training currently reads the same dev-clean cuts used for
    # validation; the full training set is '/cuts_train-clean-100.json.gz'.
    # Confirm this is an intentional debugging setup.
    cuts_train = CutSet.from_json(feature_dir + '/cuts_dev-clean.json.gz')
    print("About to get dev cuts")
    cuts_dev = CutSet.from_json(feature_dir + '/cuts_dev-clean.json.gz')

    print("About to create train dataset")
    train = K2SpeechRecognitionIterableDataset(cuts_train,
                                               max_frames=1000,
                                               shuffle=True)
    print("About to create dev dataset")
    validate = K2SpeechRecognitionIterableDataset(cuts_dev,
                                                  max_frames=1000,
                                                  shuffle=False)
    print("About to create train dataloader")
    # batch_size=None: the iterable dataset yields ready-made batches.
    train_dl = torch.utils.data.DataLoader(train,
                                           batch_size=None,
                                           num_workers=1)
    print("About to create dev dataloader")
    valid_dl = torch.utils.data.DataLoader(validate,
                                           batch_size=None,
                                           num_workers=1)

    exp_dir = 'exp'
    setup_logger('{}/log/log-train'.format(exp_dir))

    if not torch.cuda.is_available():
        logging.error('No GPU detected!')
        sys.exit(-1)
    print("About to create model")
    device_id = 0
    device = torch.device('cuda', device_id)
    model = Wav2Letter(num_classes=364, input_type='mfcc', num_features=40)
    model.to(device)

    learning_rate = 0.001
    start_epoch = 0
    num_epochs = 10
    # +inf (rather than an arbitrary large constant) guarantees the first
    # finite objective is recorded as the best one.
    best_objf = float('inf')
    best_epoch = start_epoch
    best_model_path = os.path.join(exp_dir, 'best_model.pt')
    best_epoch_info_filename = os.path.join(exp_dir, 'best-epoch-info')

    optimizer = optim.Adam(model.parameters(),
                           lr=learning_rate,
                           weight_decay=5e-4)

    for epoch in range(start_epoch, num_epochs):
        # exponential decay: lr_epoch = lr_0 * 0.4 ** epoch
        curr_learning_rate = learning_rate * pow(0.4, epoch)
        for param_group in optimizer.param_groups:
            param_group['lr'] = curr_learning_rate

        logging.info('epoch {}, learning rate {}'.format(
            epoch, curr_learning_rate))
        objf = train_one_epoch(dataloader=train_dl,
                               valid_dataloader=valid_dl,
                               model=model,
                               device=device,
                               L=L,
                               symbols=symbol_table,
                               optimizer=optimizer,
                               current_epoch=epoch,
                               num_epochs=num_epochs)
        if objf < best_objf:
            best_objf = objf
            best_epoch = epoch
            save_checkpoint(filename=best_model_path,
                            model=model,
                            epoch=epoch,
                            learning_rate=curr_learning_rate,
                            objf=objf)
            save_training_info(filename=best_epoch_info_filename,
                               model_path=best_model_path,
                               current_epoch=epoch,
                               learning_rate=curr_learning_rate,
                               objf=best_objf,
                               best_objf=best_objf,
                               best_epoch=best_epoch)

        # we always save the model for every epoch
        model_path = os.path.join(exp_dir, 'epoch-{}.pt'.format(epoch))
        save_checkpoint(filename=model_path,
                        model=model,
                        epoch=epoch,
                        learning_rate=curr_learning_rate,
                        objf=objf)
        epoch_info_filename = os.path.join(exp_dir,
                                           'epoch-{}-info'.format(epoch))
        save_training_info(filename=epoch_info_filename,
                           model_path=model_path,
                           current_epoch=epoch,
                           learning_rate=curr_learning_rate,
                           objf=objf,
                           best_objf=best_objf,
                           best_epoch=best_epoch)

    logging.warning('Done')
Example #10
0
def main():
    """Train the GRAD super-resolution model on DIV2K.

    Parses command-line options into the module-level ``args``, builds the
    model (also stored in the module-level ``model``), an Adam optimizer over
    trainable parameters and an L1 loss, and optionally resumes from
    ``args.checkpoint_file``. It then trains until ``args.epochs``. Validation
    PSNR is computed only from epoch 90 onwards. ``latest.pth`` is saved via
    ``save_checkpoint`` every epoch, together with an ``is_best`` flag.

    Raises:
        RuntimeError: if ``args.gpu`` is set but CUDA is unavailable.
    """
    global args, model

    args = parser.parse_args()
    print(args)

    if args.gpu and not torch.cuda.is_available():
        # RuntimeError subclasses Exception, so existing handlers still match.
        raise RuntimeError("No GPU found!")

    os.makedirs(args.checkpoint_dir, exist_ok=True)

    torch.manual_seed(2020)

    cudnn.benchmark = True
    # format() accepts gpu_id as either a str or an int.
    device = torch.device('cuda:{}'.format(args.gpu_id) if args.gpu else 'cpu')

    model = Grad_concat.GRAD(feats=args.feats,
                             basic_conv=args.basic_conv,
                             tail_conv=args.tail_conv).to(device)
    optimizer = optim.Adam(filter(lambda x: x.requires_grad,
                                  model.parameters()),
                           lr=args.lr)
    criterion = nn.L1Loss()

    if args.continue_train:
        # map_location lets a checkpoint saved on GPU be resumed on CPU
        # (or on a different GPU index).
        checkpoint_file = torch.load(args.checkpoint_file,
                                     map_location=device)
        model.load_state_dict(checkpoint_file['model'])
        optimizer.load_state_dict(checkpoint_file['optimizer'])
        start_epoch = checkpoint_file['epoch']
        best_epoch = checkpoint_file['best_epoch']
        best_psnr = checkpoint_file['best_psnr']
        print("continue train {}.".format(start_epoch))
    else:
        start_epoch = 0
        best_epoch = 0
        best_psnr = 0

    print("Loading dataset ...")
    train_dataset = DIV2K(args, train=True)
    valid_dataset = DIV2K(args, train=False)
    train_dataloader = DataLoader(dataset=train_dataset,
                                  batch_size=args.batch_size,
                                  shuffle=True,
                                  num_workers=1)
    valid_dataloader = DataLoader(dataset=valid_dataset, batch_size=1)

    checkpoint_name = "latest.pth"
    # Validation is skipped for earlier epochs to save time.
    first_valid_epoch = 90
    is_best = False
    for epoch in range(start_epoch + 1, args.epochs + 1):
        lr = adjust_lr(optimizer, args.lr, epoch, args.decay_step,
                       args.decay_gamma)
        print("[epoch:{}/{}]".format(epoch, args.epochs))

        train(train_dataset, train_dataloader, model, criterion, optimizer,
              device)
        if epoch >= first_valid_epoch:
            valid_psnr = valid(valid_dataset, valid_dataloader, model,
                               criterion, device)

            is_best = valid_psnr > best_psnr
            if is_best:
                best_psnr = valid_psnr
                best_epoch = epoch

            print("learning rate: {}".format(lr))
            print("PSNR: {:.4f}".format(valid_psnr))
            # '.4f' (precision) instead of '4f' (minimum width), matching the
            # PSNR line above.
            print("best PSNR: {:.4f} in epoch: {}".format(
                best_psnr, best_epoch))

        save_checkpoint(
            {
                'epoch': epoch,
                'model': model.state_dict(),
                'optimizer': optimizer.state_dict(),
                'best_psnr': best_psnr,
                'best_epoch': best_epoch,
            }, os.path.join(args.checkpoint_dir, checkpoint_name), is_best)