Example #1
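# Training loop: update and record the learning rate each epoch, train, save a
# checkpoint every save_intv epochs, and run validation every val_intv epochs.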
def main(args):
    model = build_model(args, log)
    recorder = recorders.Records(records=None)

    train_loader, val_loader = custom_dataloader(args, log)

    for epoch in range(args.start_epoch, args.epochs+1):
        model.update_learning_rate()
        recorder.insert_record('train', 'lr', epoch, model.get_learning_rate())

        train_utils.train(args, log, train_loader, model, epoch, recorder)
        if epoch % args.save_intv == 0: 
            model.save_checkpoint(epoch, recorder.records)
        log.plot_curves(recorder, 'train')

        if epoch % args.val_intv == 0:
            test_utils.test(args, log, 'val', val_loader, model, epoch, recorder)
            log.plot_curves(recorder, 'val')
Example #2
    # optimizer (this snippet begins mid-call; everything above weight_decay is an
    # assumed reconstruction, e.g. SGD with the project's lr/momentum settings)
    optimizer = optim.SGD(model.parameters(),
                          lr=opt.lr,
                          momentum=0.9,
                          weight_decay=1e-4)
    # loss
    criterion = CrossEntropyLoss().cuda()
    exp_lr_scheduler = optim.lr_scheduler.StepLR(optimizer,
                                                 step_size=20,
                                                 gamma=0.1)
    best_precision = 0
    lowest_loss = 10000

    # training
    for epoch in range(opt.start_epoch, opt.epochs):
        # training
        acc_train, loss_train = train(dataloader['train1'],
                                      model,
                                      criterion,
                                      optimizer,
                                      exp_lr_scheduler,
                                      epoch,
                                      print_interval=opt.print_interval,
                                      filename=opt.checkpoint_dir)
        # Record the training accuracy and loss of each epoch in the log file
        with open(opt.checkpoint_dir + 'record.txt', 'a') as acc_file:
            acc_file.write(
                'Epoch: %2d, train_Precision: %.8f, train_Loss: %.8f\n' %
                (epoch, acc_train, loss_train))
        # validate
        precision, avg_loss = validate(dataloader['val'],
                                       model,
                                       criterion,
                                       print_interval=opt.print_interval,
                                       filename=opt.checkpoint_dir)
        exp_lr_scheduler.step()
Example #3
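# Point-cloud autoencoder training: args.nworkers ShapenetDataProcess workers
# feed a shared queue, train() consumes it for a fixed number of epochs, and
# the worker processes are shut down afterwards.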
from multiprocessing import Queue  # Queue(1) below is shared with the data worker processes

from utils.parse_args import parse_args
from utils.train_utils import create_optimizer, train
from data_manager.shapenet import ShapenetDataProcess
from data_manager.data_process import kill_data_processes

num_epochs = 200

args = parse_args()
# args.model = PointNetFCAE_create_model(args)
args.model = MLP()

data_processes = []
data_queue = Queue(1)

for i in range(args.nworkers):
    data_processes.append(
        ShapenetDataProcess(data_queue, args, split='train', repeat=False))
    data_processes[-1].start()

# args.error = torch.nn.MSELoss()
# args.error = ChamferLoss()
args.optimizer = create_optimizer(args, args.model)

for i in range(num_epochs):
    train(args, data_queue, data_processes, i)

kill_data_processes(data_queue, data_processes)
Example #4
train_utils.save_checkpoint(
    args.dir,
    start_epoch - 1,
    model_state=model.state_dict(),
    optimizer_state=optimizer.state_dict()
)

test_res = {'loss': None, 'accuracy': None, 'nll': None}
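# Each epoch: compute the learning rate from learning_rate_schedule, apply it to
# the optimizer, train, evaluate on the test split, checkpoint every save_freq
# epochs, and collect one summary row for tabulate.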
for epoch in range(start_epoch, args.epochs + 1):
    time_ep = time.time()

    lr = learning_rate_schedule(args.lr, epoch, args.epochs)
    train_utils.adjust_learning_rate(optimizer, lr)

    train_res = train_utils.train(loaders['train'], model, optimizer, criterion, regularizer, cuda=args.cuda)
    test_res = train_utils.test(loaders['test'], model, criterion, regularizer, cuda=args.cuda)

    if epoch % args.save_freq == 0:
        train_utils.save_checkpoint(
            args.dir,
            epoch,
            model_state=model.state_dict(),
            optimizer_state=optimizer.state_dict()
        )

    time_ep = time.time() - time_ep
    values = [epoch, lr, train_res['loss'], train_res['accuracy'], test_res['nll'],
              test_res['accuracy'], time_ep]

    table = tabulate.tabulate([values], columns, tablefmt='simple', floatfmt='9.4f')
Example #5
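# Registration-style training with a composite criterion (LCC + MSE, weighted by
# args.lamb); per-epoch loss/LCC/MAE and timing are printed both to stdout and to
# the log file f.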
#optimizer = optim.RMSprop(model.parameters(), lr=LR, weight_decay=w_decay)#, momentum=0.95
#scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=Epoch_step_size, gamma=Gamma)
# lcc + grad
criterion = {
    'lcc': loss.LCC(args.win).cuda(),
    'mse': torch.nn.MSELoss().cuda(),
    'lambda': args.lamb
}

losses = []
for epoch in range(1, args.epoch + 1):
    since = time.time()

    #    scheduler.step()# adjust lr
    ### Train ###
    trn_loss, trn_lcc, trn_mae = train_utils.train(model, train_loader,
                                                   optimizer, criterion, epoch)
    print(
        'Epoch {:d}\nTrain - Loss: {:.4f} | Lcc: {:.4f} | MAE: {:.4f}'.format(
            epoch, trn_loss, trn_lcc, trn_mae))
    time_elapsed = time.time() - since
    print('Train Time {:.0f}m {:.0f}s'.format(time_elapsed // 60,
                                              time_elapsed % 60))
    print(
        'Epoch {:d}\nTrain - Loss: {:.4f} | Lcc: {:.4f} | MAE: {:.4f}'.format(
            epoch, trn_loss, trn_lcc, trn_mae),
        file=f)
    print('Train Time {:.0f}m {:.0f}s'.format(time_elapsed // 60,
                                              time_elapsed % 60),
          file=f)

    ### Val ###
Example #6
File: train.py  Project: bill4u/mergeNet
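# mergeNet training entry point: builds datasets and losses according to
# args.mode ('all', 'class' or 'offset'), optionally resumes from a checkpoint,
# supports multi-GPU via DataParallelWithCallback, and saves the best-IoU model.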
def main():
    global args, best_iou, iterations
    args = parser.parse_args()

    if args.tensorboard:
        from tensorboard_logger import configure
        print("Using tensorboard")
        configure("%s" % (args.dir))

    # model configurations
    num_classes = args.num_classes
    num_offsets = args.num_offsets
    if args.mode == 'offset':  # offset only
        num_classes = 0
    if args.mode == 'class':  # class only
        num_offsets = 0

    # model
    model = get_model(num_classes, num_offsets, args.arch, args.pretrain)

    # optionally resume from a checkpoint
    if args.resume:
        if os.path.isfile(args.resume):
            print("=> loading checkpoint '{}'".format(args.resume))
            checkpoint = torch.load(args.resume)
            args.start_epoch = checkpoint['epoch']
            best_iou = checkpoint['best_iou']
            model.load_state_dict(checkpoint['model_state'])
            if 'offset' in checkpoint:  # class mode doesn't have offset
                offset_list = checkpoint['offset']
                print("offsets are: {}".format(offset_list))
            print("=> loaded checkpoint '{}' (epoch {})".format(
                args.resume, checkpoint['epoch']))
        else:
            raise ValueError("=> no checkpoint found at '{}'".format(
                args.resume))

    # model distribution
    if args.gpu != -1:
        # DataParallel wrapper (synchronized batchnorm edition)
        if len(args.gpu) > 1:
            model = DataParallelWithCallback(model, device_ids=args.gpu)
        model.cuda()

    # dataset
    if args.mode == 'all':
        offset_list = generate_offsets(80 / args.scale, args.num_offsets)
        trainset = AllDataset(args.train_img,
                              args.train_ann,
                              num_classes,
                              offset_list,
                              scale=args.scale,
                              crop=args.crop,
                              crop_size=(args.crop_size, args.crop_size),
                              limits=args.limits)
        valset = AllDataset(args.val_img,
                            args.val_ann,
                            num_classes,
                            offset_list,
                            scale=args.scale,
                            limits=args.limits)
        class_nms = trainset.catNms
    elif args.mode == 'class':
        offset_list = None
        trainset = ClassDataset(args.train_img,
                                args.train_ann,
                                scale=args.scale,
                                crop=args.crop,
                                crop_size=(args.crop_size, args.crop_size),
                                limits=args.limits)
        valset = ClassDataset(args.val_img,
                              args.val_ann,
                              scale=args.scale,
                              limits=args.limits)
        class_nms = trainset.catNms
    elif args.mode == 'offset':
        offset_list = generate_offsets(80 / args.scale, args.num_offsets)
        print("offsets are: {}".format(offset_list))
        trainset = OffsetDataset(args.train_img,
                                 args.train_ann,
                                 offset_list,
                                 scale=args.scale,
                                 crop=args.crop,
                                 crop_size=args.crop_size,
                                 limits=args.limits)
        valset = OffsetDataset(args.val_img,
                               args.val_ann,
                               offset_list,
                               scale=args.scale,
                               limits=args.limits)
        class_nms = None

    trainloader = torch.utils.data.DataLoader(trainset,
                                              num_workers=4,
                                              batch_size=args.batch_size,
                                              shuffle=True)
    valloader = torch.utils.data.DataLoader(valset,
                                            num_workers=4,
                                            batch_size=4)
    num_train = len(trainset)
    num_val = len(valset)
    print('Training samples: {0} \n'
          'Validation samples: {1}'.format(num_train, num_val))

    # define optimizer
    optimizer = torch.optim.SGD(model.parameters(),
                                args.lr,
                                momentum=args.momentum,
                                nesterov=args.nesterov,
                                weight_decay=args.weight_decay)
    if args.resume:
        optimizer.load_state_dict(checkpoint['optimizer'])

    # define loss functions

    criterion_ofs = torch.nn.BCEWithLogitsLoss().cuda()

    if args.mode == 'all':
        criterion_cls = torch.nn.BCEWithLogitsLoss().cuda()
        criterion_ofs = torch.nn.BCEWithLogitsLoss().cuda()
    elif args.mode == 'class':
        criterion_cls = torch.nn.BCEWithLogitsLoss().cuda()
        criterion_ofs = None
    elif args.mode == 'offset':
        criterion_cls = None
        if args.loss == 'bce':
            print('Using Binary Cross Entropy Loss')
            criterion_ofs = torch.nn.BCEWithLogitsLoss().cuda()
        elif args.loss == 'mbce':
            print('Using Weighted Multiclass BCE Loss')
            criterion_ofs = MultiBCEWithLogitsLoss().cuda()
        elif args.loss == 'dice':
            print('Using Soft Dice Loss (0 mode)')
            criterion_ofs = SoftDiceLoss(mode='0').cuda()
        else:
            print('Using Cross Entropy Loss')
            criterion_ofs = CrossEntropyLossOneHot().cuda()

    # define learning rate scheduler
    if not args.milestones:
        milestones = [args.epochs]
    else:
        milestones = args.milestones
    scheduler = lr_scheduler.MultiStepLR(optimizer,
                                         milestones=milestones,
                                         gamma=0.2,
                                         last_epoch=args.start_epoch - 1)

    # start iteration count
    iterations = args.start_epoch * int(len(trainset) / args.batch_size)

    # train
    for epoch in range(args.start_epoch, args.epochs):
        scheduler.step()
        iterations = train(trainloader,
                           model,
                           optimizer,
                           args.batch_size,
                           epoch,
                           iterations,
                           criterion_cls=criterion_cls,
                           class_nms=class_nms,
                           criterion_ofs=criterion_ofs,
                           offset_list=offset_list,
                           print_freq=args.print_freq,
                           log_freq=args.log_freq,
                           tensorboard=args.tensorboard,
                           score=args.score,
                           alpha=args.alpha)
        val_iou = validate(valloader,
                           model,
                           args.batch_size,
                           epoch,
                           iterations,
                           criterion_cls=criterion_cls,
                           class_nms=class_nms,
                           criterion_ofs=criterion_ofs,
                           offset_list=offset_list,
                           print_freq=args.print_freq,
                           log_freq=args.log_freq,
                           tensorboard=args.tensorboard,
                           score=args.score,
                           alpha=args.alpha)
        # visualize some example outputs after each epoch
        if args.visual_freq > 0 and epoch % args.visual_freq == 0:
            outdir = '{}/imgs/{}'.format(args.dir, epoch)
            if not os.path.exists(outdir):
                os.makedirs(outdir)
            sample(model, valloader, outdir, num_classes, num_offsets)

        # save checkpoint
        is_best = val_iou > best_iou
        best_iou = max(val_iou, best_iou)
        if args.gpu != -1 and len(args.gpu) > 1:
            state_dict = {
                'epoch': epoch + 1,
                'model_state':
                model.module.state_dict(),  # remove 'module' in checkpoint
                'best_iou': best_iou,
                'optimizer': optimizer.state_dict()
            }
        else:
            state_dict = {
                'epoch': epoch + 1,
                'model_state': model.state_dict(),
                'best_iou': best_iou,
                'optimizer': optimizer.state_dict()
            }
        if args.mode != 'class':
            state_dict['offset'] = offset_list
        save_checkpoint(args.dir, state_dict, is_best)

    print('Best validation mean iou: ', best_iou)
Example #7
def main_worker(gpu, ngpus_per_node, args):
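    # Distributed training, one process per GPU: printing is silenced on all but
    # GPU 0, the model is wrapped in DistributedDataParallel, mixed precision is
    # optional via NVIDIA apex, and local rank 0 saves a checkpoint every epoch.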
    # Retrieve config file
    p = create_config(args.config_env, args.config_exp)

    # Check gpu id
    args.gpu = gpu
    p['gpu'] = gpu
    if args.gpu != 0:

        def print_pass(*args):
            pass

        builtins.print = print_pass
    else:
        sys.stdout = Logger(os.path.join(p['output_dir'], 'log_file.txt'))

    if args.dist_url == "env://" and args.rank == -1:
        args.rank = int(os.environ["RANK"])

    # For multiprocessing distributed training, rank needs to be the
    # global rank among all the processes
    args.rank = args.rank * ngpus_per_node + gpu
    dist.init_process_group(backend=args.dist_backend,
                            init_method=args.dist_url,
                            world_size=args.world_size,
                            rank=args.rank)

    print('Python script is {}'.format(os.path.abspath(__file__)))
    print(colored(p, 'red'))

    # Get model
    print(colored('Retrieve model', 'blue'))
    model = ContrastiveModel(p)
    torch.cuda.set_device(args.gpu)
    model.cuda(args.gpu)

    # Optimizer
    print(colored('Retrieve optimizer', 'blue'))
    optimizer = get_optimizer(p, model.parameters())
    print(optimizer)

    # Nvidia-apex
    if args.nvidia_apex:
        print(colored('Using mixed precision training', 'blue'))
        from apex import amp
        model, optimizer = amp.initialize(model,
                                          optimizer,
                                          opt_level="O2",
                                          keep_batchnorm_fp32=True,
                                          loss_scale="dynamic")
    else:
        amp = None

    # When using a single GPU per process and per
    # DistributedDataParallel, we need to divide the batch size
    # ourselves based on the total number of GPUs we have
    p['train_batch_size'] = int(p['train_batch_size'] / ngpus_per_node)
    p['num_workers'] = int(
        (p['num_workers'] + ngpus_per_node - 1) / ngpus_per_node)
    model = torch.nn.parallel.DistributedDataParallel(
        model, device_ids=[args.gpu], find_unused_parameters=True)

    # CUDNN
    print(colored('Set CuDNN benchmark', 'blue'))
    torch.backends.cudnn.benchmark = True

    # Dataset
    print(colored('Retrieve dataset', 'blue'))

    # Transforms
    train_transform = get_train_transformations()
    print(train_transform)
    train_dataset = DatasetKeyQuery(
        get_train_dataset(p, transform=None),
        train_transform,
        downsample_sal=not p['model_kwargs']['upsample'])
    train_sampler = torch.utils.data.distributed.DistributedSampler(
        train_dataset)
    train_dataloader = torch.utils.data.DataLoader(
        train_dataset,
        batch_size=p['train_batch_size'],
        shuffle=(train_sampler is None),
        num_workers=p['num_workers'],
        pin_memory=True,
        sampler=train_sampler,
        drop_last=True,
        collate_fn=collate_custom)
    print(colored('Train samples %d' % (len(train_dataset)), 'yellow'))
    print(colored(train_dataset, 'yellow'))

    # Resume from checkpoint
    if os.path.exists(p['checkpoint']):
        print(
            colored('Restart from checkpoint {}'.format(p['checkpoint']),
                    'blue'))
        loc = 'cuda:{}'.format(args.gpu)
        checkpoint = torch.load(p['checkpoint'], map_location=loc)
        optimizer.load_state_dict(checkpoint['optimizer'])
        model.load_state_dict(checkpoint['model'])
        if args.nvidia_apex:
            amp.load_state_dict(checkpoint['amp'])
        start_epoch = checkpoint['epoch']

    else:
        print(
            colored('No checkpoint file at {}'.format(p['checkpoint']),
                    'blue'))
        start_epoch = 0
        model = model.cuda()

    # Main loop
    print(colored('Starting main loop', 'blue'))

    for epoch in range(start_epoch, p['epochs']):
        print(colored('Epoch %d/%d' % (epoch + 1, p['epochs']), 'yellow'))
        print(colored('-' * 10, 'yellow'))

        # Adjust lr
        lr = adjust_learning_rate(p, optimizer, epoch)
        print('Adjusted learning rate to {:.5f}'.format(lr))

        # Train
        print('Train ...')
        eval_train = train(p, train_dataloader, model, optimizer, epoch, amp)

        # Checkpoint
        if args.rank % ngpus_per_node == 0:
            print('Checkpoint ...')
            if args.nvidia_apex:
                torch.save(
                    {
                        'optimizer': optimizer.state_dict(),
                        'model': model.state_dict(),
                        'amp': amp.state_dict(),
                        'epoch': epoch + 1
                    }, p['checkpoint'])

            else:
                torch.save(
                    {
                        'optimizer': optimizer.state_dict(),
                        'model': model.state_dict(),
                        'epoch': epoch + 1
                    }, p['checkpoint'])
Example #8
def main(train_file=os.path.join(Config.root_path, 'data/ranking/train.tsv'),
         dev_file=os.path.join(Config.root_path, 'data/ranking/dev.tsv'),
         model_path=Config.bert_model,
         epochs=10,
         batch_size=32,
         lr=2e-05,
         patience=3,
         max_grad_norm=10.0,
         checkpoint=None):
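    # BERT sentence-pair ranking: AdamW with weight decay excluded for bias and
    # LayerNorm parameters, ReduceLROnPlateau driven by validation accuracy, and
    # early stopping once `patience` epochs pass without improvement.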
    logging.info(20 * "=" + " Preparing for training " + 20 * "=")
    bert_tokenizer = BertTokenizer.from_pretrained(Config.vocab_path,
                                                   do_lower_case=True)
    device = torch.device("cuda") if Config.is_cuda else torch.device("cpu")
    if not os.path.exists(os.path.dirname(model_path)):
        os.mkdir(os.path.dirname(model_path))
    logging.info("\t* Loading training data...")
    train_dataset = DataPrecessForSentence(bert_tokenizer=bert_tokenizer,
                                           file=train_file,
                                           max_char_len=Config.max_seq_len)
    train_dataloader = DataLoader(
        dataset=train_dataset,
        batch_size=batch_size,
        shuffle=True,
    )
    logging.info("\t* Loading validation data...")
    dev_dataset = DataPrecessForSentence(bert_tokenizer=bert_tokenizer,
                                         file=dev_file,
                                         max_char_len=Config.max_seq_len)
    dev_dataloader = DataLoader(
        dataset=dev_dataset,
        batch_size=batch_size,
        shuffle=True,
    )
    logging.info("\t* Building model...")
    model = BertModelTrain().to(device)

    # parameters to be optimized
    param_optimizer = list(model.named_parameters())
    no_decay = ['bias', 'LayerNorm.bias', 'LayerNorm.weight']
    optimizer_grouped_parameters = [{
        'params':
        [p for n, p in param_optimizer if not any(nd in n for nd in no_decay)],
        'weight_decay':
        0.01
    }, {
        'params':
        [p for n, p in param_optimizer if any(nd in n for nd in no_decay)],
        'weight_decay':
        0.0
    }]
    optimizer = AdamW(optimizer_grouped_parameters, lr=lr)
    scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer,
                                                           mode="max",
                                                           factor=0.85,
                                                           patience=0)
    best_score = 0.0
    start_epoch = 1
    # Data for loss curves plot
    epochs_count = []
    train_losses = []
    valid_losses = []
    # Continuing training from a checkpoint if one was given as argument
    if checkpoint:
        checkpoint = torch.load(checkpoint)
        start_epoch = checkpoint["epoch"] + 1
        best_score = checkpoint["best_score"]
        logging.info(
            "\t* Training will continue on existing model from epoch {}...".
            format(start_epoch))
        model.load_state_dict(checkpoint["model"])
        optimizer.load_state_dict(checkpoint["optimizer"])
        epochs_count = checkpoint["epochs_count"]
        train_losses = checkpoint["train_losses"]
        valid_losses = checkpoint["valid_losses"]
        # Compute loss and accuracy before starting (or resuming) training.
        _, valid_loss, valid_accuracy, auc = validate(model, dev_dataloader)
        logging.info(
            "\t* Validation loss before training: {:.4f}, accuracy: {:.4f}%, \
            auc: {:.4f}".format(valid_loss, (valid_accuracy * 100), auc))

    # -------------------- Training epochs ------------------- #
    logging.info("\n" + 20 * "=" +
                 "Training Bert model on device: {}".format(device) + 20 * "=")
    patience_counter = 0
    for i in range(start_epoch, epochs + 1):
        logging.info("* starting training epoch {}".format(i))
        train_time, train_loss, train_acc = train(
            model=model,
            dataloader=train_dataloader,
            optimizer=optimizer,
            epoch_number=i,
            max_gradient_norm=max_grad_norm)
        train_losses.append(train_loss)
        logging.info("-> Training time: {:.4f}s, loss = {:.4f}, \
            accuracy: {:.4f}%".format(train_time, train_loss,
                                      (train_acc * 100)))

        logging.info("* Validation for epoch {}:".format(i))
        val_time, val_loss, val_acc, score = validate(
            model=model, dataloader=dev_dataloader)
        valid_losses.append(val_loss)
        logging.info(
            "-> Valid. time: {:.4f}s, loss: {:.4f}, accuracy: {:.4f}%, \
            auc: {:.4f}\n".format(val_time, val_loss, (val_acc * 100), score))
        scheduler.step(val_acc)
        # Early stopping on validation accuracy.
        if val_acc < best_score:
            patience_counter += 1
        else:
            best_score = val_acc
            patience_counter = 0
            torch.save(
                {
                    "epoch": i,
                    "model": model.state_dict(),
                    "best_score": best_score,
                    "epochs_count": epochs_count,
                    "train_losses": train_losses,
                    "valid_losses": valid_losses
                }, model_path)
        if patience_counter >= patience:
            logging.info(
                "-> Early stopping: patience limit reached, stopping...")
            break
Example #9
# this snippet begins mid-call; the DataLoader head below is reconstructed to
# mirror test_iterator, and the train_loader dataset name is assumed
train_iterator = torch.utils.data.DataLoader(train_loader,
                                             batch_size=batch_size,
                                             shuffle=True,
                                             num_workers=0,
                                             collate_fn=lstm_collate,
                                             pin_memory=True)
test_iterator = torch.utils.data.DataLoader(test_loader,
                                            batch_size=batch_size,
                                            shuffle=False,
                                            num_workers=0,
                                            collate_fn=lstm_collate,
                                            pin_memory=True)

new_top1, top1 = 0.0, 0.0
isbest = False
for epoch in range(len(lr_range) * max_epochs):
    train(model_ft, optimizer, ce_loss, train_iterator, epoch, log_file,
          lr_scheduler)

    new_top1 = test(model_ft, ce_loss, test_iterator, epoch, "Test", log_file)
#    isbest = True if new_top1 >= top1 else False
#
#    if save_all_weights:
#        weight_file = os.path.join(output_dir, model_name + '_{:03d}.pth'.format(epoch))
#    else:
#        weight_file = os.path.join(output_dir, model_name + '_ckpt.pth')
#    print_and_save('Saving weights to {}'.format(weight_file), log_file)
#    torch.save({'epoch': epoch + 1,
#                'state_dict': model_ft.state_dict(),
#                'optimizer': optimizer.state_dict(),
#                'top1': new_top1}, weight_file)
#    if isbest:
#        best = os.path.join(output_dir, model_name+'_best.pth')
Example #10
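# COCO class+offset training: selectable classification loss (BCE, weighted
# multiclass BCE, soft Dice, or one-hot cross-entropy), MultiStepLR schedule,
# running IoU metrics, and best-IoU checkpointing.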
def main():
    global args, best_iou, iterations
    args = parser.parse_args()

    if args.tensorboard:
        from tensorboard_logger import configure
        print("Using tensorboard")
        configure("%s" % (args.dir))

    offset_list = generate_offsets(args.num_offsets)

    # model configurations
    num_classes = args.num_classes
    num_offsets = args.num_offsets

    # model
    model = get_model(num_classes, num_offsets, args.arch, args.pretrain)
    model = model.cuda()

    # dataset
    trainset = COCODataset(args.train_img,
                           args.train_ann,
                           num_classes,
                           offset_list,
                           scale=args.scale,
                           size=(args.train_image_size, args.train_image_size),
                           limits=args.limits,
                           crop=args.crop)
    trainloader = torch.utils.data.DataLoader(trainset,
                                              num_workers=4,
                                              batch_size=args.batch_size,
                                              shuffle=True)
    valset = COCODataset(args.val_img,
                         args.val_ann,
                         num_classes,
                         offset_list,
                         scale=args.scale,
                         limits=args.limits)
    valloader = torch.utils.data.DataLoader(valset,
                                            num_workers=4,
                                            batch_size=4)
    num_train = len(trainset)
    num_val = len(valset)
    print('Training samples: {0} \n'
          'Validation samples: {1}'.format(num_train, num_val))

    # define optimizer
    optimizer = torch.optim.SGD(model.parameters(),
                                args.lr,
                                momentum=args.momentum,
                                nesterov=args.nesterov,
                                weight_decay=args.weight_decay)

    # optionally resume from a checkpoint
    if args.resume:
        if os.path.isfile(args.resume):
            print("=> loading checkpoint '{}'".format(args.resume))
            checkpoint = torch.load(args.resume)
            args.start_epoch = checkpoint['epoch']
            best_iou = checkpoint['best_iou']
            model.load_state_dict(checkpoint['state_dict'])
            optimizer.load_state_dict(checkpoint['optimizer'])
            offset_list = checkpoint['offset']
            print("=> loaded checkpoint '{}' (epoch {})".format(
                args.resume, checkpoint['epoch']))
        else:
            raise ValueError("=> no checkpoint found at '{}'".format(
                args.resume))
    print("offsets are: {}".format(offset_list))

    # define loss functions
    if args.loss == 'bce':
        print('Using Binary Cross Entropy Loss')
        criterion_cls = torch.nn.BCEWithLogitsLoss().cuda()
    elif args.loss == 'mbce':
        print('Using Weighted Multiclass BCE Loss')
        criterion_cls = MultiBCEWithLogitsLoss().cuda()
    elif args.loss == 'dice':
        print('Using Soft Dice Loss')
        criterion_cls = SoftDiceLoss().cuda()
    else:
        print('Using Cross Entropy Loss')
        criterion_cls = CrossEntropyLossOneHot().cuda()

    criterion_ofs = torch.nn.BCEWithLogitsLoss().cuda()

    # define learning rate scheduler
    if not args.milestones:
        milestones = [args.epochs]
    else:
        milestones = args.milestones
    scheduler = lr_scheduler.MultiStepLR(optimizer,
                                         milestones=milestones,
                                         gamma=0.1,
                                         last_epoch=args.start_epoch - 1)

    # start iteration count
    iterations = args.start_epoch * int(len(trainset) / args.batch_size)

    # define score metrics
    score_metrics_train = runningScore(num_classes, trainset.catNms)
    score_metrics = runningScore(num_classes, valset.catNms)
    offset_metrics_train = offsetIoU(offset_list)
    offset_metrics_val = offsetIoU(offset_list)

    # train
    for epoch in range(args.start_epoch, args.epochs):
        scheduler.step()
        iterations = train(trainloader,
                           model,
                           criterion_cls,
                           criterion_ofs,
                           optimizer,
                           num_classes,
                           args.batch_size,
                           epoch,
                           iterations,
                           print_freq=args.print_freq,
                           log_freq=args.log_freq,
                           tensorboard=args.tensorboard,
                           score_metrics=score_metrics_train,
                           offset_metrics=offset_metrics_train,
                           alpha=args.alpha)
        val_iou = validate(valloader,
                           model,
                           criterion_cls,
                           criterion_ofs,
                           num_classes,
                           args.batch_size,
                           epoch,
                           iterations,
                           print_freq=args.print_freq,
                           log_freq=args.log_freq,
                           tensorboard=args.tensorboard,
                           score_metrics=score_metrics,
                           offset_metrics=offset_metrics_val,
                           alpha=args.alpha)
        # visualize some example outputs after each epoch
        if args.visualize:
            outdir = '{}/imgs/{}'.format(args.dir, epoch + 1)
            if not os.path.exists(outdir):
                os.makedirs(outdir)
            sample(num_classes, num_offsets, model, valloader, outdir)

        is_best = val_iou > best_iou
        best_iou = max(val_iou, best_iou)
        save_checkpoint(
            args.dir, {
                'epoch': epoch + 1,
                'state_dict': model.state_dict(),
                'best_iou': best_iou,
                'optimizer': optimizer.state_dict(),
                'offset': offset_list,
            }, is_best)
    print('Best validation mean iou: ', best_iou)
Example #11
File: qa.py  Project: k-weng/qa-retrieval
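# Question retrieval on AskUbuntu: an LSTM or CNN encoder over pretrained word
# embeddings is trained with MultiMarginLoss; dev MRR (or AUC on the Android
# corpus when args.android is set) selects the best checkpoint.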
def main():
    global args, best_mrr, best_auc
    args = parser.parse_args()
    cuda_available = torch.cuda.is_available()
    print(args)

    corpus_file = 'data/askubuntu/text_tokenized.txt.gz'
    dataset = UbuntuDataset(corpus_file)
    corpus = dataset.get_corpus()

    if args.embedding == 'askubuntu':
        embedding_file = 'data/askubuntu/vector/vectors_pruned.200.txt.gz'
    else:
        embedding_file = 'data/glove/glove.pruned.txt.gz'

    embedding_iter = Embedding.iterator(embedding_file)
    embedding = Embedding(args.embed, embedding_iter)
    print('Embeddings loaded.')

    corpus_ids = embedding.corpus_to_ids(corpus)
    padding_id = embedding.vocab_ids['<padding>']

    train_file = 'data/askubuntu/train_random.txt'
    train_data = dataset.read_annotations(train_file)

    dev_file = 'data/askubuntu/dev.txt'
    dev_data = dataset.read_annotations(dev_file, max_neg=-1)
    dev_batches = batch_utils.generate_eval_batches(corpus_ids, dev_data,
                                                    padding_id)

    assert args.model in ['lstm', 'cnn']
    if args.model == 'lstm':
        model = LSTM(args.embed, args.hidden)
    else:
        model = CNN(args.embed, args.hidden)

    print(model)
    print('Parameters: {}'.format(params(model)))

    optimizer = torch.optim.Adam(model.parameters(), args.lr)
    criterion = nn.MultiMarginLoss(margin=args.margin)

    if cuda_available:
        criterion = criterion.cuda()

    if args.load:
        if os.path.isfile(args.load):
            print('Loading checkpoint.')
            checkpoint = torch.load(args.load)
            args.start_epoch = checkpoint['epoch']
            best_mrr = checkpoint.get('best_mrr', -1)
            best_auc = checkpoint.get('best_auc', -1)
            model.load_state_dict(checkpoint['state_dict'])

            print('Loaded checkpoint at epoch {}.'.format(checkpoint['epoch']))
        else:
            print('No checkpoint found here.')

    if args.eval:
        test_file = 'data/askubuntu/test.txt'
        test_data = dataset.read_annotations(test_file, max_neg=-1)
        test_batches = batch_utils.generate_eval_batches(
            corpus_ids, test_data, padding_id)

        print('Evaluating on dev set.')
        train_utils.evaluate_metrics(args, model, embedding, dev_batches,
                                     padding_id)

        print('Evaluating on test set.')
        train_utils.evaluate_metrics(args, model, embedding, test_batches,
                                     padding_id)
        return

    if args.android:
        android_file = 'data/android/corpus.tsv.gz'
        android_dataset = AndroidDataset(android_file)
        android_ids = embedding.corpus_to_ids(android_dataset.get_corpus())

        dev_pos_file = 'data/android/dev.pos.txt'
        dev_neg_file = 'data/android/dev.neg.txt'
        android_data = android_dataset.read_annotations(
            dev_pos_file, dev_neg_file)

        android_batches = batch_utils.generate_eval_batches(
            android_ids, android_data, padding_id)

    for epoch in range(args.start_epoch, args.epochs):
        train_batches = batch_utils.generate_train_batches(
            corpus_ids, train_data, args.batch_size, padding_id)

        train_utils.train(args, model, embedding, optimizer, criterion,
                          train_batches, padding_id, epoch)

        map, mrr, p1, p5 = train_utils.evaluate_metrics(
            args, model, embedding, dev_batches, padding_id)

        auc = -1
        if args.android:
            auc = train_utils.evaluate_auc(args, model, embedding,
                                           android_batches, padding_id)

        is_best = auc > best_auc if args.android else mrr > best_mrr
        best_mrr = max(mrr, best_mrr)
        best_auc = max(auc, best_auc)
        save(
            args, {
                'epoch': epoch + 1,
                'arch': 'lstm',
                'state_dict': model.state_dict(),
                'best_mrr': best_mrr,
                'best_auc': best_auc,
            }, is_best)