Example #1
    # Hyperparameters
    epochs = 100
    lr = 0.01

    train_loader, valid_loader = data.load_data(batch_size=64)
    print("Train samples: %d" % len(train_loader.dataset))
    print("Valid samples: %d" % len(valid_loader.dataset))
    model = model.model()
    model = model.to(device)

    criterion_lss1 = nn.BCELoss()
    criterion_lss2 = nn.KLDivLoss(reduction='batchmean')
    criterion_ce = nn.CrossEntropyLoss()

    optimizer = optim.Adam(model.parameters(), lr=lr)

    time_str = time.strftime("%m_%d-%Hh%Mm%Ss", time.localtime())
    file = open("../log/%s.csv" % time_str, 'w')
    writer = csv.writer(file)
    headers = [
        "train_loss", "train_acc", "train_lsl", "train_lss_1", "train_lss_2",
        "train_lsd", "valid_loss", "valid_acc", "valid_lsl", "valid_lss_1",
        "valid_lss_2", "valid_lsd"
    ]
    writer.writerow(headers)

    best_acc = 0.0
    for epoch in range(epochs):
        print("-" * 5 + "Epoch:  %3d/%3d" % (epoch + 1, epochs) + "-" * 5)
        train_result = train()
        valid_result = valid()
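        # --- Hypothetical continuation (not in the original excerpt) ---
        # Assuming train() and valid() each return their metrics in the same
        # order as the CSV headers above, the epoch results could be logged
        # and the best checkpoint kept roughly like this:
        writer.writerow(list(train_result) + list(valid_result))
        file.flush()

        valid_acc = valid_result[1]  # assumed position of "valid_acc"
        if valid_acc > best_acc:
            best_acc = valid_acc
            torch.save(model.state_dict(), "../log/%s_best.pth" % time_str)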
Example #2
def train(model, optimizer, scheduler, global_step, train_dataset, dev_dataset,
          opt, collator, best_eval_loss):

    if opt.is_main:
        try:
            tb_logger = torch.utils.tensorboard.SummaryWriter(
                Path(opt.checkpoint_dir) / opt.name)
        except Exception:
            tb_logger = None
            logger.warning('Tensorboard is not available.')
    train_sampler = DistributedSampler(
        train_dataset) if opt.is_distributed else RandomSampler(train_dataset)
    train_dataloader = DataLoader(train_dataset,
                                  sampler=train_sampler,
                                  batch_size=opt.per_gpu_batch_size,
                                  drop_last=True,
                                  num_workers=10,
                                  collate_fn=collator)

    loss, curr_loss = 0.0, 0.0
    epoch = 1
    model.train()
    while global_step < opt.total_steps:
        if opt.is_distributed:
            train_sampler.set_epoch(epoch)
        epoch += 1
        for i, batch in enumerate(train_dataloader):
            global_step += 1
            (idx, question_ids, question_mask, passage_ids, passage_mask,
             gold_score) = batch
            _, _, _, train_loss = model(
                question_ids=question_ids.cuda(),
                question_mask=question_mask.cuda(),
                passage_ids=passage_ids.cuda(),
                passage_mask=passage_mask.cuda(),
                gold_score=gold_score.cuda(),
            )

            train_loss.backward()

            if global_step % opt.accumulation_steps == 0:
                torch.nn.utils.clip_grad_norm_(model.parameters(), opt.clip)
                optimizer.step()
                scheduler.step()
                model.zero_grad()

            train_loss = src.util.average_main(train_loss, opt)
            curr_loss += train_loss.item()

            if global_step % opt.eval_freq == 0:
                eval_loss, inversions, avg_topk, idx_topk = evaluate(
                    model, dev_dataset, collator, opt)
                if eval_loss < best_eval_loss:
                    best_eval_loss = eval_loss
                    if opt.is_main:
                        src.util.save(model, optimizer, scheduler, global_step,
                                      best_eval_loss, opt, dir_path,
                                      'best_dev')
                model.train()
                if opt.is_main:
                    log = f"{global_step} / {opt.total_steps}"
                    log += f" -- train: {curr_loss/opt.eval_freq:.6f}"
                    log += f", eval: {eval_loss:.6f}"
                    log += f", inv: {inversions:.1f}"
                    log += f", lr: {scheduler.get_last_lr()[0]:.6f}"
                    for k in avg_topk:
                        log += f" | avg top{k}: {100*avg_topk[k]:.1f}"
                    for k in idx_topk:
                        log += f" | idx top{k}: {idx_topk[k]:.1f}"
                    logger.info(log)

                    if tb_logger is not None:
                        tb_logger.add_scalar("Evaluation", eval_loss,
                                             global_step)
                        tb_logger.add_scalar("Training",
                                             curr_loss / (opt.eval_freq),
                                             global_step)
                    curr_loss = 0

            if opt.is_main and global_step % opt.save_freq == 0:
                src.util.save(model, optimizer, scheduler, global_step,
                              best_eval_loss, opt, dir_path,
                              f"step-{global_step}")
            if global_step > opt.total_steps:
                break
Example #3
# Instantiate the model w/ hyperparams
vocab_sz = len(oneHotdict) + 1
output_size = 17

n_layers = 2
train_on_gpu = torch.cuda.is_available()
model = model.RNNSentiment(num_layer=config.n_layers,
                           vocab_size=vocab_sz,
                           hidden_dim=config.hidden_dim,
                           embedding_dim=config.embedding_dim)
# loss and optimization functions
lr = 0.001

criterion = nn.BCELoss()
optimizer = torch.optim.Adam(model.parameters(), lr=lr)

# training params
epochs = 10
counter = 0
print_every = 1000
clip = 5  # gradient clipping

# move model to GPU, if available
if (train_on_gpu):
    model.cuda()

for e in range(epochs):
    model.train()
    # batch loop
    for inputs, labels in tqdm(Train_loader, total=len(Train_loader)):
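        # --- Hypothetical loop body (not in the original excerpt) ---
        # A typical batch step for this BCELoss setup, assuming the model's
        # forward takes only the padded input batch and returns per-example
        # probabilities (the actual RNNSentiment signature, e.g. whether it
        # also takes/returns a hidden state, is not shown here):
        counter += 1
        if train_on_gpu:
            inputs, labels = inputs.cuda(), labels.cuda()

        model.zero_grad()
        output = model(inputs)
        loss = criterion(output.squeeze(), labels.float())
        loss.backward()
        nn.utils.clip_grad_norm_(model.parameters(), clip)
        optimizer.step()

        if counter % print_every == 0:
            print("Epoch: {}/{}  Step: {}  Loss: {:.6f}".format(
                e + 1, epochs, counter, loss.item()))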
Example #4
File: run.py  Project: aspgln/rcc
def run(args):

    ### Data Loading

    if args.task == 0:
        print('Task 0: MR Dataset Prediction')
        augmentor = transforms.Compose([
            transforms.Lambda(lambda x: torch.Tensor(x)),
            mrnet.torchsample.transforms.RandomRotate(25),
            mrnet.torchsample.transforms.RandomTranslate([0.11, 0.11]),
            mrnet.torchsample.transforms.RandomFlip(),
            transforms.Lambda(
                lambda x: x.repeat(3, 1, 1, 1).permute(1, 0, 2, 3)),
        ])
        job = 'acl'
        plane = 'sagittal'
        train_ds = mrnet.mrnet_dataloader.MRDataset(
            '/data/larson2/RCC_dl/MRNet-v1.0/data/',
            job,
            plane,
            transform=augmentor,
            train=True)
        train_loader = torch.utils.data.DataLoader(train_ds,
                                                   batch_size=1,
                                                   shuffle=True,
                                                   num_workers=11,
                                                   drop_last=False)

        val_ds = mrnet.mrnet_dataloader.MRDataset(
            '/data/larson2/RCC_dl/MRNet-v1.0/data/', job, plane, train=False)
        val_loader = torch.utils.data.DataLoader(val_ds,
                                                 batch_size=1,
                                                 shuffle=True,
                                                 num_workers=11,
                                                 drop_last=False)

    elif args.task == 1:
        print('Task 1: clear cell grade prediction')
        path = '/data/larson2/RCC_dl/new/clear_cell/'

        augmentor = transforms.Compose([
            transforms.Lambda(lambda x: torch.Tensor(x)),
            src.dataloader.Rescale(-160, 240),  # reset dynamic range
            transforms.Lambda(
                lambda x: x.repeat(3, 1, 1, 1).permute(3, 0, 1, 2)),
            #             src.dataloader.Normalize(),
            #             src.dataloader.Crop(90),
            #             src.dataloader.RandomCenterCrop(90),
            src.dataloader.RandomHorizontalFlip(),
            src.dataloader.RandomRotate(25),
            src.dataloader.Resize(256),
        ])

        augmentor2 = transforms.Compose([
            transforms.Lambda(lambda x: torch.Tensor(x)),
            src.dataloader.Rescale(-160, 240),  # reset dynamic range
            transforms.Lambda(
                lambda x: x.repeat(3, 1, 1, 1).permute(3, 0, 1, 2)),
            #         src.dataloader.Normalize(),
            #         src.dataloader.Crop(90),
            src.dataloader.Resize(256),
        ])

        train_ds = src.dataloader.RCCDataset_h5(path,
                                                mode='train',
                                                transform=augmentor)
        train_loader = DataLoader(train_ds,
                                  batch_size=1,
                                  shuffle=True,
                                  num_workers=1,
                                  drop_last=False)

        val_ds = src.dataloader.RCCDataset_h5(path,
                                              mode='val',
                                              transform=augmentor2)
        val_loader = DataLoader(val_ds,
                                batch_size=1,
                                shuffle=True,
                                num_workers=1,
                                drop_last=False)
        print(f'train size: {len(train_loader)}')
        print(f'val size: {len(val_loader)}')

        pos_weight = args.weight

    ### Sanity checks
    print('Summary: ')

    print(f'\ttrain size: {len(train_loader)}')
    print(f'\tval size: {len(val_loader)}')
    print('\tDatatype = ', train_ds[1][0].dtype)
    print('\tMin = ', train_ds[1][0].min())
    print('\tMax = ', train_ds[1][0].max())
    print('\tInput size', train_ds[0][0].shape)
    print('\tweight = ', args.weight)

    ### Some trackers
    log_root_folder = "/data/larson2/RCC_dl/logs/"

    now = datetime.now()
    now = now.strftime("%Y%m%d-%H%M%S")
    logdir = os.path.join(
        log_root_folder,
        f"task_{args.task}_{args.prefix_name}_model{args.model}_{now}")
    os.makedirs(logdir)
    print(f'logdir = {logdir}')

    writer = SummaryWriter(logdir)

    ### Model Construction

    ## Select Model
    if args.model == 1:
        model = src.model.MRNet()
    elif args.model == 2:
        model = src.model.MRNet2()
    elif args.model == 3:
        model = src.model.MRNetBN()
    elif args.model == 4:
        model = src.model.MRResNet()
    elif args.model == 5:
        model = src.model.MRNetScratch()
    elif args.model == 6:
        model = src.model.TDNet()
    else:
        print('Invalid model name')
        return

    ## Weight Initialization

    ## Training Strategy
    device = torch.device(
        "cuda:{}".format(args.gpu) if torch.cuda.is_available() else "cpu")
    print('\tCuda:', torch.cuda.is_available(), f'\n\tdevice = {device}')

    optimizer = optim.Adam(model.parameters(), lr=args.lr, weight_decay=0.1)

    if args.lr_scheduler == "plateau":
        scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer,
                                                               patience=3,
                                                               factor=.3,
                                                               threshold=1e-4,
                                                               verbose=True)
    elif args.lr_scheduler == "step":
        scheduler = torch.optim.lr_scheduler.StepLR(optimizer,
                                                    step_size=3,
                                                    gamma=args.gamma)

    model = model.to(device)

    ### Ready?
    best_val_loss = float('inf')
    best_val_auc = float(0)
    iteration_change_loss = 0
    t_start_training = time.time()

    ### Here we go
    for epoch in range(args.epochs):
        current_lr = src.train3d.get_lr(optimizer)

        t_start = time.time()

        train_loss, train_auc = src.train3d.train_model(
            model, train_loader, device, epoch, args.epochs, optimizer, writer,
            current_lr, args.log_every, args.weight)
        val_loss, val_auc = src.train3d.evaluate_model(
            model,
            val_loader,
            device,
            epoch,
            args.epochs,
            writer,
            current_lr,
            args.log_every,
        )

        if args.lr_scheduler == 'plateau':
            scheduler.step(val_loss)
        elif args.lr_scheduler == 'step':
            scheduler.step()

        t_end = time.time()
        delta = t_end - t_start

        print(
            "train loss : {0} | train auc {1} | val loss {2} | val auc {3} | elapsed time {4} s"
            .format(train_loss, train_auc, val_loss, val_auc, delta))

        iteration_change_loss += 1
        print('-' * 30)

        model_root_dir = "/data/larson2/RCC_dl/models/"

        if val_auc > best_val_auc:
            best_val_auc = val_auc
            if bool(args.save_model):
                file_name = f'task_{args.task}_model_{args.model}_{args.prefix_name}_val_auc_{val_auc:0.4f}_train_auc_{train_auc:0.4f}_epoch_{epoch+1}_weight_{args.weight}_lr_{args.lr}_gamma_{args.gamma}_lrsche_{args.lr_scheduler}.pth'
                #                 for f in os.listdir(model_root_dir):
                #                     if  (args.prefix_name in f):
                #                         os.remove(os.path.join(model_root_dir, f))
                torch.save(model, os.path.join(model_root_dir, file_name))

        if val_loss < best_val_loss:
            best_val_loss = val_loss
            iteration_change_loss = 0

        if iteration_change_loss == args.patience:
            print(
                'Early stopping after {0} epochs without a decrease in the val loss'
                .format(iteration_change_loss))
            break

    t_end_training = time.time()
    print(f'training took {t_end_training - t_start_training} s')
Example #5
def run(args):
    print('Task 1: clear cell grade prediction')
    path = '/data/larson2/RCC_dl/new/clear_cell/'

    transform = {
        'train':
        transforms.Compose([
            transforms.Lambda(lambda x: torch.Tensor(x)),
            src.dataloader.Rescale(-160, 240,
                                   zero_center=True),  # reset dynamic range
            transforms.Lambda(
                lambda x: x.repeat(3, 1, 1, 1).permute(3, 0, 1, 2)),
            #     src.dataloader.Normalize(),
            #     src.dataloader.Crop(110),
            #     src.dataloader.RandomCenterCrop(90),
            src.dataloader.RandomHorizontalFlip(),
            #     src.dataloader.RandomRotate(25),
            src.dataloader.Resize(256)
        ]),
        'val':
        transforms.Compose([
            transforms.Lambda(lambda x: torch.Tensor(x)),
            src.dataloader.Rescale(-160, 240,
                                   zero_center=True),  # reset dynamic range
            transforms.Lambda(
                lambda x: x.repeat(3, 1, 1, 1).permute(3, 0, 1, 2)),
            #       src.dataloader.Normalize(),
            #       src.dataloader.Crop(90),
            src.dataloader.Resize(256)
        ])
    }

    my_dataset = {
        'train':
        src.dataloader.RCCDataset_h5(path,
                                     mode='train',
                                     transform=transform['train']),
        'val':
        src.dataloader.RCCDataset_h5(path,
                                     mode='val',
                                     transform=transform['val'])
    }

    my_loader = {
        x: DataLoader(my_dataset[x], batch_size=1, shuffle=True, num_workers=4)
        for x in ['train', 'val']
    }

    print('train size: ', len(my_loader['train']))
    print('val size: ', len(my_loader['val']))

    ### Sanity checks
    print('Summary: ')
    print('\ttrain size: ', len(my_loader['train']))
    print('\tval size: ', len(my_loader['val']))
    print('\tDatatype = ', next(iter(my_loader['train']))[0].dtype)
    print('\tMin = ', next(iter(my_loader['train']))[0].min())
    print('\tMax = ', next(iter(my_loader['train']))[0].max())
    print('\tInput size', next(iter(my_loader['train']))[0].shape)
    #     print('\tweight = ', args.weight)

    ### Tensorboard Log Setup
    log_root_folder = "/data/larson2/RCC_dl/logs/"
    now = datetime.now()
    now = now.strftime("%Y%m%d-%H%M%S")
    logdir = os.path.join(
        log_root_folder,
        f"{now}_model_{args.model}_{args.prefix_name}_epoch_{args.epochs}_weight_{args.weight}_lr_{args.lr}_gamma_{args.gamma}_lrsche_{args.lr_scheduler}_{now}"
    )
    #     os.makedirs(logdir)
    print(f'\tlogdir = {logdir}')

    writer = SummaryWriter(logdir)

    ### Model Selection

    device = torch.device(
        "cuda:{}".format(args.gpu) if torch.cuda.is_available() else "cpu")

    model = src.model.TDNet()
    model = model.to(device)

    writer.add_graph(model, my_dataset['train'][0][0].to(device))

    print('\tCuda:', torch.cuda.is_available(), f'\n\tdevice = {device}')

    optimizer = optim.Adam(model.parameters(), lr=args.lr, weight_decay=0.1)

    if args.lr_scheduler == "plateau":
        scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer,
                                                               patience=3,
                                                               factor=.3,
                                                               threshold=1e-4,
                                                               verbose=True)
    elif args.lr_scheduler == "step":
        scheduler = torch.optim.lr_scheduler.StepLR(optimizer,
                                                    step_size=3,
                                                    gamma=args.gamma)

    pos_weight = torch.FloatTensor([args.weight]).to(device)
    criterion = torch.nn.BCEWithLogitsLoss(pos_weight=pos_weight)

    ### Ready?
    best_val_loss = float('inf')
    best_val_auc = float(0)
    best_model_wts = copy.deepcopy(model.state_dict())
    iteration_change_loss = 0
    t_start_training = time.time()

    ### Here we go
    for epoch in range(args.epochs):
        current_lr = get_lr(optimizer)
        t_start = time.time()

        epoch_loss = {'train': 0., 'val': 0.}
        epoch_corrects = {'train': 0., 'val': 0.}

        epoch_acc = 0.0
        epoch_AUC = 0.0

        for phase in ['train', 'val']:
            if phase == 'train':
                if args.lr_scheduler == "step":
                    scheduler.step()
                model.train()
            else:
                model.eval()

            running_losses = []
            running_corrects = 0.
            y_trues = []
            y_probs = []
            y_preds = []

            print('lr: ', current_lr)
            for i, (inputs, labels, header) in enumerate(my_loader[phase]):
                optimizer.zero_grad()

                inputs = inputs.to(device)
                labels = labels.to(device)

                # forward
                # track history only in train
                with torch.set_grad_enabled(phase == 'train'):
                    outputs = model(inputs.float())  # raw logits
                    probs = torch.sigmoid(
                        outputs)  # [0, 1] probability, shape = s * 1
                    preds = torch.round(
                        probs
                    )  # 0 or 1, shape = s * 1, prediction for each slice
                    pt_pred, _ = torch.mode(
                        preds, 0
                    )  # take majority vote, shape = 1, prediction for each patient

                    count0 = (preds == 0).sum().float()
                    count1 = (preds == 1).sum().float()
                    pt_prob = count1 / (preds.shape[0])

                    # convert label to slice level
                    loss = criterion(outputs, labels.repeat(
                        inputs.shape[1], 1))  # inputs shape = 1*s*3*256*256

                    # backward + optimize only if in training phases
                    if phase == 'train':
                        loss.backward()
                        optimizer.step()

                # multiply loss by slice num per batch?
                running_losses.append(loss.item())  # * inputs.size(0)
                running_corrects += torch.sum(preds == labels.data)

                y_trues.append(int(labels.item()))
                y_probs.append(pt_prob.item())  # use ratio to get probability
                y_preds.append(pt_pred.item())

                writer.add_scalar(f'{phase}/Loss', loss.item(),
                                  epoch * len(my_loader[phase]) + i)
                writer.add_pr_curve(f'{phase}/pr_curve', y_trues, y_probs, 0)

                if (i % args.log_every == 0) & (i > 0):
                    print(
                        'Epoch: {0}/{1} | Single batch number : {2}/{3} | avg loss:{4} | Acc: {5:.4f} | lr: {6}'
                        .format(epoch + 1, args.epochs, i,
                                len(my_loader[phase]),
                                np.round(np.mean(running_losses), 4),
                                (running_corrects / len(my_loader[phase])),
                                current_lr))

            # epoch statistics
            epoch_loss[phase] = np.round(np.mean(running_losses), 4)
            epoch_corrects[phase] = (running_corrects / len(my_loader[phase]))

            cm = confusion_matrix(y_trues, y_preds, labels=[0, 1])
            src.helper.print_cm(cm, ['0', '1'])
            sens, spec, acc = src.helper.compute_stats(y_trues, y_preds)
            print('sens: {:.4f}'.format(sens))
            print('spec: {:.4f}'.format(spec))
            print('acc:  {:.4f}'.format(acc))
            print()

        print(
            '\nSummary: train loss: {0} | val loss: {1} | train acc: {2:.4f} | val acc: {3:.4f}'
            .format(epoch_loss['train'], epoch_loss['val'],
                    epoch_corrects['train'], epoch_corrects['val']))
        print('-' * 30)
Example #6
def train(model, optimizer, scheduler, step, train_dataset, eval_dataset, opt, collator, best_dev_em, checkpoint_path):

    if opt.is_main:
        try:
            tb_logger = torch.utils.tensorboard.SummaryWriter(Path(opt.checkpoint_dir)/opt.name)
        except Exception:
            tb_logger = None
            logger.warning('Tensorboard is not available.')

    torch.manual_seed(opt.global_rank + opt.seed) #different seed for different sampling depending on global_rank
    train_sampler = RandomSampler(train_dataset)
    train_dataloader = DataLoader(
        train_dataset,
        sampler=train_sampler,
        batch_size=opt.per_gpu_batch_size,
        drop_last=True,
        num_workers=10,
        collate_fn=collator
    )

    loss, curr_loss = 0.0, 0.0
    epoch = 1
    model.train()
    while step < opt.total_steps:
        epoch += 1
        for i, batch in enumerate(train_dataloader):
            step += 1
            (idx, labels, _, context_ids, context_mask) = batch

            train_loss = model(
                input_ids=context_ids.cuda(),
                attention_mask=context_mask.cuda(),
                labels=labels.cuda()
            )[0]

            train_loss.backward()

            if step % opt.accumulation_steps == 0:
                torch.nn.utils.clip_grad_norm_(model.parameters(), opt.clip)
                optimizer.step()
                scheduler.step()
                model.zero_grad()

            train_loss = src.util.average_main(train_loss, opt)
            curr_loss += train_loss.item()

            if step % opt.eval_freq == 0:
                dev_em = evaluate(model, eval_dataset, tokenizer, collator, opt)
                model.train()
                if opt.is_main:
                    if dev_em > best_dev_em:
                        best_dev_em = dev_em
                        src.util.save(model, optimizer, scheduler, step, best_dev_em,
                                  opt, checkpoint_path, 'best_dev')
                    log = f"{step} / {opt.total_steps} |"
                    log += f"train: {curr_loss/opt.eval_freq:.3f} |"
                    log += f"evaluation: {100*dev_em:.2f}EM |"
                    log += f"lr: {scheduler.get_last_lr()[0]:.5f}"
                    logger.info(log)
                    if tb_logger is not None:
                        tb_logger.add_scalar("Evaluation", dev_em, step)
                        tb_logger.add_scalar("Training", curr_loss / (opt.eval_freq), step)
                    curr_loss = 0

            if opt.is_main and step % opt.save_freq == 0:
                src.util.save(model, optimizer, scheduler, step, best_dev_em,
                          opt, checkpoint_path, f"step-{step}")
            if step > opt.total_steps:
                break
Example #7
def train_keypoint_rcnn(data=None,
                        epochs: int = None,
                        lr: float = 1e-5,
                        pretrained: str = None):

    model = src.model.keypoint_rcnn

    if not isinstance(pretrained, str) and pretrained is not None:
        raise ValueError(
            f'Argument "pretrained" must be a path to a valid mask file, '
            f'not {pretrained} with type {type(pretrained)}')
    if epochs is None:
        epochs = 500

    if pretrained is not None:
        print('Loading...')
        model.load_state_dict(torch.load(pretrained))

    if torch.cuda.is_available(): device = 'cuda:0'
    else: device = 'cpu'

    # tests = KeypointData('/media/DataStorage/Dropbox (Partners HealthCare)/DetectStereocillia/data/keypoint_train_data')
    model = model.train().to(device)
    optimizer = torch.optim.Adam(model.parameters(), lr=lr)

    for e in range(epochs):
        epoch_loss = []
        time_1 = time.clock_gettime_ns(time.CLOCK_MONOTONIC)
        for image, data_dict in data:
            for key in data_dict:
                data_dict[key] = data_dict[key].to(device)
            assert image.shape[1] == 3

            optimizer.zero_grad()
            loss = model(image.to(device), [data_dict])
            losses = 0
            for key in loss:
                losses += loss[key]
            losses.backward()
            epoch_loss.append(losses.item())
            optimizer.step()
        time_2 = time.clock_gettime_ns(time.CLOCK_MONOTONIC)

        delta_time = np.round((np.abs(time_2 - time_1) / 1e9) / 60, decimals=2)

        #  --------- This is purely to output a nice bar for training --------- #
        if e % 5 == 0:
            if e > 0:
                print('\b \b' * len(out_str), end='')
            progress_bar = '[' + '█' * int(np.round(e / epochs, decimals=1) * 10) + \
                           ' ' * int(10 - np.round(e / epochs, decimals=1) * 10) + \
                           f'] {np.round(e / epochs, decimals=3)}%'

            out_str = f'epoch: {e} ' + progress_bar + f'| time remaining: {delta_time * (epochs-e)} min | epoch loss: {torch.tensor(epoch_loss).mean().item()}'
            print(out_str, end='')

        # If it's the final epoch, print the final string
        elif e == epochs - 1:
            print('\b \b' * len(out_str), end='')
            progress_bar = '[' + '█' * 10 + f'] {1.0}'
            out_str = f'epoch: {epochs} ' + progress_bar + f'| time remaining: {0} min | epoch loss: {torch.tensor(epoch_loss).mean().item()}'
            print(out_str)

        torch.save(model.state_dict(), 'models/keypoint_rcnn.mdl')

    model.eval()
    out = model(image.unsqueeze(0).to(device))
Example #8
def work_process(process, args):
    # get the rank of the current process within the world
    args.rank = args.nr * args.gpus * args.process_num + process
    # get the rank of the GPU to which the current process will be sent
    gpu = process // args.process_num + args.st

    os.environ['MASTER_ADDR'] = '202.38.73.168'
    os.environ['MASTER_PORT'] = '6632'
    # set parameters of the distributed environment
    torch.cuda.set_device(gpu)
    device = torch.device("cuda:" + str(gpu))
    distributed.init_process_group(backend='nccl',
                                   world_size=args.world_size,
                                   rank=args.rank)
    #move model to specified gpu
    torch.manual_seed(0)
    model = src.model.get_model(args).to(device)

    #check whether to resume from checkpoint or not
    start_epoch = 0
    best_acc = 0
    if args.resume:
        # Load checkpoint.
        print('==> Resuming from checkpoint..')
        assert os.path.isdir(
            'checkpoint'), 'Error: no checkpoint directory found!'
        checkpoint = torch.load('./checkpoint/' + args.model + "_init.pth")
        model.load_state_dict(checkpoint['net'])
        best_acc = checkpoint['acc']
        start_epoch = checkpoint['epoch']
    '''
    #set parameters of each worker to be same
    torch.cuda.manual_seed(1)
    for p in model.parameters():
        if p.requires_grad==False:
            continue
        p.data=torch.randn_like(p.data)*0.1
    '''
    # create distributed optimizer
    optimizer = DSGD(model.parameters(),
                     model,
                     update_period=args.period,
                     lr=args.lr,
                     local=args.local,
                     args=args)
    #load dataset
    train_dataset, test_dataset = get_dataset(args.dataset, args)

    train_sampler = UnshuffleDistributedSampler(train_dataset,
                                                num_replicas=args.world_size,
                                                rank=args.rank,
                                                cluster_data=args.cluster_data,
                                                Dirichlet=args.Dirichlet)
    train_loader = data.DataLoader(train_dataset,
                                   args.batch_size,
                                   sampler=train_sampler)
    test_loader = data.DataLoader(test_dataset,
                                  args.batch_size,
                                  shuffle=True,
                                  num_workers=1)

    #train model and record related result
    trainer = Trainer(model, optimizer, train_loader, test_loader, device)
    trainer.fit(best_acc, start_epoch, args.epochs, args)

    # save command-line arguments
    if args.rank == 0:
        path = "Args/{}_{}_{}_{}.pkl".format(args.model, args.dataset,
                                             iid(args), get_alg_name(args))
        with open(path, 'wb') as outfile:
            pickle.dump(args, outfile)
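
# --- Hypothetical launcher (not part of the original file) ---
# work_process expects the local process index as its first argument, so one
# way to start it is torch.multiprocessing.spawn, assuming parse_args() is a
# project-specific parser that fills in nr, gpus, process_num, world_size, st, ...
import torch.multiprocessing as mp

if __name__ == '__main__':
    args = parse_args()  # hypothetical helper; the real arg parsing is not shown
    mp.spawn(work_process, nprocs=args.gpus * args.process_num, args=(args,))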