Example #1
def train(args, model: CharRNN, step, epoch, corpus, char_to_id, criterion,
          model_file):
    optimizer = torch.optim.Adam(model.parameters(), lr=args.lr)
    batch_chars = args.window_size * args.batch_size
    save = lambda ep: torch.save(
        {
            'state': model.state_dict(),
            'epoch': ep,
            'step': step,
        }, str(model_file))
    log = Path(args.root).joinpath('train.log').open('at', encoding='utf8')
    for epoch in range(epoch, args.n_epochs + 1):
        try:
            losses = []
            n_iter = args.epoch_batches or (len(corpus) // batch_chars)
            report_each = min(10, n_iter - 1)
            tr = tqdm.tqdm(total=n_iter * batch_chars)
            tr.set_description('Epoch {}'.format(epoch))
            model.train()
            for i in range(n_iter):
                inputs, targets = random_batch(
                    corpus,
                    batch_size=args.batch_size,
                    window_size=args.window_size,
                    char_to_id=char_to_id,
                )
                loss = train_model(model, criterion, optimizer, inputs,
                                   targets)
                step += 1
                losses.append(loss)
                tr.update(batch_chars)
                mean_loss = np.mean(losses[-report_each:])
                tr.set_postfix(loss=mean_loss)
                if i and i % report_each == 0:
                    write_event(log, step, loss=mean_loss)
            tr.close()
            save(ep=epoch + 1)
        except KeyboardInterrupt:
            print('\nGot Ctrl+C, saving checkpoint...')
            save(ep=epoch)
            print('done.')
            return
        if args.valid_corpus:
            valid_result = validate(args, model, criterion, char_to_id)
            write_event(log, step, **valid_result)
    print('Done training for {} epochs'.format(args.n_epochs))
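Every example on this page logs through a small write_event helper that none of the snippets define. A minimal sketch, assuming JSON-lines logging with a step counter and timestamp (the real helper may differ):

import json
from datetime import datetime

def write_event(log, step, **data):
    # append one JSON object per line; 'step' and a timestamp are added
    data['step'] = step
    data['dt'] = datetime.now().isoformat()
    print(json.dumps(data, sort_keys=True), file=log)
    log.flush()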
def train(args, model: nn.Module, criterion, *, params,
          train_loader, valid_loader, init_optimizer, use_cuda,
          n_epochs=None, patience=2, max_lr_changes=3) -> bool:
    lr = args.lr
    n_epochs = n_epochs or args.n_epochs
    params = list(params)  # in case params is not a list
    # add params into optimizer
    optimizer = init_optimizer(params, lr)

    #model load/save path
    run_root = Path(args.run_root)

    model_path = run_root / 'model.pt'

    if model_path.exists():
        print('loading existing weights from model.pt')
        state = load_model(model, model_path)
        epoch = state['epoch']
        step = state['step']
        best_valid_loss = state['best_valid_loss']
        best_f1 = state['best_f1']
    else:
        epoch = 1
        step = 0
        best_valid_loss = 0.0  # 0.0 rather than float('inf'): ROC AUC is tracked in this slot below
        best_f1 = 0


    lr_changes = 0

    save = lambda ep: torch.save({
        'model': model.state_dict(),
        'epoch': ep,
        'step': step,
        'best_valid_loss': best_valid_loss,
        'best_f1': best_f1
    }, str(model_path))

    save_where = lambda ep, svpath: torch.save({
        'model': model.state_dict(),
        'epoch': ep,
        'step': step,
        'best_valid_loss': best_valid_loss,
        'best_f1': best_f1
    }, str(svpath))

    report_each = 100
    log = run_root.joinpath('train.log').open('at', encoding='utf8')
    valid_losses = []
    valid_f1s = []
    lr_reset_epoch = epoch

    #epoch loop
    for epoch in range(epoch, n_epochs + 1):
        model.train()
        # nPosTr/nNegTr are sample counts defined elsewhere in the original module
        tq = tqdm.tqdm(total=(args.epoch_size or
                              (nPosTr + nNegTr) * len(train_loader) * args.batch_size))

        if epoch >= 20 and epoch % 2 == 0:
            lr = lr * 0.9
            adjust_learning_rate(optimizer, lr)
            print('lr updated to %0.8f' % lr)

        tq.set_description('Epoch %d, lr %0.8f' % (epoch, lr))
        losses = []
        tl = train_loader
        if args.epoch_size:
            tl = islice(tl, args.epoch_size // args.batch_size)
        try:
            mean_loss = 0

            for i, batch_dat in enumerate(tl):  # each batch is a dict of feature tensors
                featComp = batch_dat['feat_comp']
                featLoc = batch_dat['feat_loc']
                featId = batch_dat['feat_id']
                featEnsemble = batch_dat['feat_ensemble_score']
                targets = batch_dat['target']

                # print(featComp.shape,featLoc.shape,batch_dat['feat_comp_dim'],batch_dat['feat_loc_dim'])

                if use_cuda:
                    featComp, featLoc, targets, featId, featEnsemble = (
                        featComp.cuda(), featLoc.cuda(), targets.cuda(),
                        featId.cuda(), featEnsemble.cuda())

                # common_feat_comp, common_feat_loc, feat_comp_loc, outputs = model(feat_comp=featComp, feat_loc=featLoc)
                model_output = model(feat_comp=featComp, feat_loc=featLoc,
                                     id_loc=featId,
                                     feat_ensemble_score=featEnsemble)
                outputs = model_output['outputs']


                # outputs = outputs.squeeze()

                # loss1 = softmax_loss(outputs, targets)
                # loss2 = TripletLossV1(margin=0.5)(feats,targets)
                # loss1 = criterion(outputs,targets)


                if args.cos_sim_loss:
                    out_comp_feat = model_output['comp_feat']
                    out_loc_feat = model_output['loc_feat']
                    cos_targets = 2 * targets.float() - 1.0  # map {0, 1} to {-1, +1}
                    loss = criterion(out_comp_feat, out_loc_feat, cos_targets)
                    lossType = 'cosine'
                else:
                    loss = softmax_loss(outputs, targets)
                    lossType = 'softmax'

                batch_size = featComp.size(0)

                (batch_size * loss).backward()
                if (i + 1) % args.step == 0:
                    optimizer.step()
                    optimizer.zero_grad()
                    step += 1
                tq.update(batch_size)
                losses.append(loss.item())
                mean_loss = np.mean(losses[-report_each:])
                tq.set_postfix(loss=f'{mean_loss:.3f}')

                if i and i % report_each == 0:
                    write_event(log, step, loss=mean_loss)

            write_event(log, step, loss=mean_loss)
            tq.close()
            print('saving')
            save(epoch + 1)
            print('validation')
            valid_metrics = validation(model, criterion, valid_loader, use_cuda, lossType=lossType)
            write_event(log, step, **valid_metrics)
            valid_loss = valid_metrics['valid_loss']
            valid_top1 = valid_metrics['valid_top1']
            valid_roc = valid_metrics['auc']
            valid_losses.append(valid_loss)


            # tricky: reuse valid_loss/best_valid_loss to track ROC AUC instead of loss
            valid_loss = valid_roc
            if valid_loss > best_valid_loss:  # ROC AUC: bigger is better
                best_valid_loss = valid_loss
                shutil.copy(str(model_path), str(run_root / 'model_loss_best.pt'))

        except KeyboardInterrupt:
            tq.close()
            # print('Ctrl+C, saving snapshot')
            # save(epoch)
            # print('done.')

            return False
    return True
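The (i + 1) % args.step == 0 test above is gradient accumulation: backward() runs on every batch, but the optimizer steps only every args.step batches, simulating a larger effective batch. A self-contained sketch of the bare pattern; accum_steps stands in for args.step, and the toy model and data are illustrative:

import torch
import torch.nn as nn

model = nn.Linear(10, 2)
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
loader = [(torch.randn(8, 10), torch.randint(0, 2, (8,))) for _ in range(8)]

accum_steps = 4  # plays the role of args.step
optimizer.zero_grad()
for i, (inputs, targets) in enumerate(loader):
    loss = criterion(model(inputs), targets)
    (loss / accum_steps).backward()  # gradients accumulate across iterations
    if (i + 1) % accum_steps == 0:
        optimizer.step()             # apply the accumulated gradient
        optimizer.zero_grad()

Note that the function above instead multiplies the loss by batch_size before backward(), which folds the batch size into the effective learning rate rather than averaging over the accumulated batches.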
Example #3
def train(args,
          model,
          optimizer,
          scheduler,
          tokenizer,
          ner_index,
          *,
          train_loader,
          valid_df,
          valid_loader,
          epoch_length,
          n_epochs=None):
    n_epochs = n_epochs or args.n_epochs

    run_root = Path('../experiments/' + args.run_root)
    model_path = run_root / ('tagger_model-%d.pt' % args.fold)
    best_model_path = run_root / ('best-model-%d.pt' % args.fold)
    if best_model_path.exists():
        state, best_valid_score = load_model(model, best_model_path)
        start_epoch = state['epoch']
        best_epoch = start_epoch
    else:
        best_valid_score = 0
        start_epoch = 0
        best_epoch = 0
    step = 0
    criterion = CrossEntropyLoss().cuda()
    report_each = 10000
    log = run_root.joinpath('train-%d.log' % args.fold).open('at',
                                                             encoding='utf8')

    for epoch in range(start_epoch, start_epoch + n_epochs):
        model.train()

        tq = tqdm.tqdm(total=epoch_length)
        losses = []

        mean_loss = 0
        device = torch.device("cuda", 0)
        for i, (ori_sen, token, token_type, start, end, insert_pos, start_ner,
                end_ner) in enumerate(train_loader):
            input_mask = (token > 0).to(device)
            token, token_type = token.to(device), token_type.to(device)
            start, end = start.to(device), end.to(device)
            insert_pos = insert_pos.to(device)
            start_ner, end_ner = start_ner.to(device), end_ner.to(device)
            outputs = model(input_ids=token,
                            attention_mask=input_mask,
                            token_type_ids=token_type,
                            start=start,
                            end=end,
                            insert_pos=insert_pos,
                            start_ner=start_ner,
                            end_ner=end_ner)

            loss = outputs[0]
            loss.backward()
            if (i + 1) % args.step == 0:
                optimizer.step()
                optimizer.zero_grad()
                scheduler.step()
                step += 1

            tq.update(args.batch_size)
            losses.append(loss.item() * args.step)
            mean_loss = np.mean(losses[-report_each:])
            tq.set_postfix(loss=f'{mean_loss:.5f}')
            lr = get_learning_rate(optimizer)
            tq.set_description(f'Epoch {epoch}, lr {lr:.6f}')
            if i and i % report_each == 0:
                write_event(log, step, loss=mean_loss)
            # break
        write_event(log, step, epoch=epoch, loss=mean_loss)
        tq.close()

        valid_metrics = validate(model, valid_loader, valid_df, args,
                                 tokenizer, ner_index)
        # write_event(log, step, **valid_metrics)
        current_score = valid_metrics['rouge-1']['f']
        if current_score > best_valid_score:
            save_model(model, epoch, step, mean_loss, model_path)
            best_valid_score = current_score
            print('saved new best model')
    return True
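get_learning_rate is not shown in the snippet; a plausible one-liner, assuming all parameter groups share one rate:

def get_learning_rate(optimizer):
    # assumes every parameter group carries the same learning rate
    return optimizer.param_groups[0]['lr']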
def train(args, model: nn.Module, optimizer, scheduler, criterion, *,
          train_loader, valid_df, valid_loader, epoch_length, patience=1,
          n_epochs=None) -> bool:
    n_epochs = n_epochs or args.n_epochs

    run_root = Path('../experiments/' + args.run_root)
    model_path = run_root / ('model-%d.pt' % args.fold)
    best_model_path = run_root / ('best-model-%d.pt' % args.fold)
    if best_model_path.exists():
        state, best_valid_score = load_model(model, best_model_path)
        start_epoch = state['epoch']
        best_epoch = start_epoch
    else:
        best_valid_score = 0
        start_epoch = 0
        best_epoch = 0
    step = 0

    if args.mode == "train_all":
        current_score = 0.95

    save = lambda ep: torch.save({
        'model': model.module.state_dict() if args.multi_gpu == 1 else model.state_dict(),
        'epoch': ep,
        'step': step,
        'best_valid_loss': current_score
    }, str(model_path))
    report_each = 10000
    log = run_root.joinpath('train-%d.log' % args.fold).open('at', encoding='utf8')

    for epoch in range(start_epoch, start_epoch + n_epochs):
        model.train()

        lr = get_learning_rate(optimizer)
        tq = tqdm.tqdm(total=epoch_length)
        tq.set_description(f'Epoch {epoch}, lr {lr}')
        losses = []

        mean_loss = 0
        for i, (inputs, _, targets, weights) in enumerate(train_loader):
            attention_mask = (inputs > 0).cuda()
            inputs, targets, weights = inputs.cuda(), targets.cuda(), weights.unsqueeze(1).cuda()

            outputs = model(inputs, attention_mask=attention_mask, labels=None)

            loss = criterion(outputs, targets) / args.step
            batch_size = inputs.size(0)
            if (i + 1) % args.step == 0:
                with amp.scale_loss(loss, optimizer) as scaled_loss:
                    scaled_loss.backward()
                optimizer.step()
                optimizer.zero_grad()
            else:
                with amp.scale_loss(loss, optimizer, delay_unscale=True) as scaled_loss:
                    scaled_loss.backward()

            tq.update(batch_size)
            losses.append(loss.item() * args.step)
            mean_loss = np.mean(losses[-report_each:])
            tq.set_postfix(loss=f'{mean_loss:.5f}')
            if i and i % report_each == 0:
                write_event(log, step, loss=mean_loss)

        write_event(log, step, epoch=epoch, loss=mean_loss)
        tq.close()

        if args.mode == "train":
            valid_metrics = validation(model, criterion, valid_df, valid_loader, args)
            write_event(log, step, **valid_metrics)
            current_score = valid_metrics['score']
        save(epoch + 1)
        if scheduler is not None and args.mode == "train":
            scheduler.step(current_score)

        if args.mode == "train":
            if current_score > best_valid_score:
                best_valid_score = current_score
                shutil.copy(str(model_path), str(best_model_path))
                best_epoch = epoch
    return True
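The amp.scale_loss(...) blocks come from NVIDIA's apex mixed-precision library; delay_unscale=True keeps gradients scaled while they accumulate between optimizer steps. For these calls to work, the model and optimizer must have been wrapped beforehand, roughly like this (the O1 opt level is an assumption):

import torch
import torch.nn as nn
from apex import amp  # NVIDIA apex, assumed installed

model = nn.Linear(10, 2).cuda()
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
# opt_level='O1': patch ops to run in fp16 where it is numerically safe
model, optimizer = amp.initialize(model, optimizer, opt_level='O1')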
Example #5
def train(args):
    print("Traning")

    print("Prepaing data")
    masks = pd.read_csv(os.path.join(args.dataset_dir, args.train_masks))
    unique_img_ids = get_unique_img_ids(masks, args)
    train_df, valid_df = get_balanced_train_valid(masks, unique_img_ids, args)

    if args.stage == 0:
        train_shape = (256, 256)
        batch_size = args.stage0_batch_size
        extra_epoch = args.stage0_epochs
    elif args.stage == 1:
        train_shape = (384, 384)
        batch_size = args.stage1_batch_size
        extra_epoch = args.stage1_epochs
    elif args.stage == 2:
        train_shape = (512, 512)
        batch_size = args.stage2_batch_size
        extra_epoch = args.stage2_epochs
    elif args.stage == 3:
        train_shape = (768, 768)
        batch_size = args.stage3_batch_size
        extra_epoch = args.stage3_epochs

    print("Stage {}".format(args.stage))

    train_transform = DualCompose([
        Resize(train_shape),
        HorizontalFlip(),
        VerticalFlip(),
        RandomRotate90(),
        Shift(),
        Transpose(),
        # ImageOnly(RandomBrightness()),
        # ImageOnly(RandomContrast()),
    ])
    val_transform = DualCompose([
        Resize(train_shape),
    ])

    train_dataloader = make_dataloader(train_df,
                                       args,
                                       batch_size,
                                       args.shuffle,
                                       transform=train_transform)
    val_dataloader = make_dataloader(valid_df,
                                     args,
                                     batch_size // 2,
                                     args.shuffle,
                                     transform=val_transform)

    # Build model
    model = UNet()
    optimizer = Adam(model.parameters(), lr=args.lr)
    scheduler = StepLR(optimizer, step_size=args.decay_fr, gamma=0.1)
    if args.gpu and torch.cuda.is_available():
        model = model.cuda()

    # Restore model ...
    run_id = 4

    model_path = Path('model_{run_id}.pt'.format(run_id=run_id))
    if not model_path.exists() and args.stage > 0:
        raise ValueError(
            'model_{run_id}.pt does not exist; run the stage-0 training first.'.format(
                run_id=run_id))
    if model_path.exists():
        state = torch.load(str(model_path))
        last_epoch = state['epoch']
        step = state['step']
        model.load_state_dict(state['model'])
        print('Restore model, epoch {}, step {:,}'.format(last_epoch, step))
    else:
        last_epoch = 1
        step = 0

    log_file = open('train_{run_id}.log'.format(run_id=run_id),
                    'at',
                    encoding='utf8')

    loss_fn = LossBinary(jaccard_weight=args.iou_weight)

    valid_losses = []

    print("Start training ...")
    # fast-forward the scheduler past epochs completed before this resume
    for _ in range(last_epoch):
        scheduler.step()

    for epoch in range(last_epoch, last_epoch + extra_epoch):
        scheduler.step()
        model.train()
        random.seed()
        tq = tqdm(total=len(train_dataloader) * batch_size)
        tq.set_description('Run Id {}, Epoch {} of {}, lr {}'.format(
            run_id, epoch, last_epoch + extra_epoch,
            args.lr * (0.1**(epoch // args.decay_fr))))
        losses = []
        try:
            mean_loss = 0.
            for i, (inputs, targets) in enumerate(train_dataloader):
                inputs, targets = torch.tensor(inputs), torch.tensor(targets)
                if args.gpu and torch.cuda.is_available():
                    inputs = inputs.cuda()
                    targets = targets.cuda()

                outputs = model(inputs)
                loss = loss_fn(outputs, targets)
                optimizer.zero_grad()  # clear gradients left over from the previous batch
                loss.backward()
                optimizer.step()

                step += 1
                tq.update(batch_size)
                losses.append(loss.item())
                mean_loss = np.mean(losses[-args.log_fr:])
                tq.set_postfix(loss="{:.5f}".format(mean_loss))

                if i and (i % args.log_fr) == 0:
                    write_event(log_file, step, loss=mean_loss)
            write_event(log_file, step, loss=mean_loss)
            tq.close()
            save_model(model, epoch, step, model_path)

            valid_metrics = validation(args, model, loss_fn, val_dataloader)
            write_event(log_file, step, **valid_metrics)
            valid_loss = valid_metrics['valid_loss']
            valid_losses.append(valid_loss)

        except KeyboardInterrupt:
            tq.close()
            print('Ctrl+C, saving snapshot')
            save_model(model, epoch, step, model_path)
            print('Terminated.')
    print('Done.')
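LossBinary(jaccard_weight=...) is the BCE-plus-soft-Jaccard loss familiar from the TernausNet code. A sketch of one common formulation; the exact weighting in this repo may differ:

import torch
import torch.nn as nn

class LossBinary:
    """BCE on logits minus a weighted log soft-Jaccard term."""

    def __init__(self, jaccard_weight=0.0):
        self.bce = nn.BCEWithLogitsLoss()
        self.jaccard_weight = jaccard_weight

    def __call__(self, outputs, targets):
        loss = self.bce(outputs, targets)
        if self.jaccard_weight:
            eps = 1e-15
            probs = torch.sigmoid(outputs)
            intersection = (probs * targets).sum()
            union = probs.sum() + targets.sum() - intersection
            loss -= self.jaccard_weight * torch.log(
                (intersection + eps) / (union + eps))
        return loss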
Example #6
                final_color = colors[0]
                final_mask = masks[0]
                for i in range(inputs_array.shape[0] - 1):
                    final_color = cv2.hconcat((final_color, colors[i + 1]))
                    final_mask = cv2.hconcat((final_mask, masks[i + 1]))
                final_mask = cv2.cvtColor(final_mask, cv2.COLOR_GRAY2BGR)
                final = cv2.vconcat((final_color, final_mask))
                # cv2.imwrite(str(root / 'color_images.png'), np.uint8(255*(final_color * 0.5 + 0.5)))
                # cv2.imshow("rgb", final * 0.5 + 0.5)
                # cv2.imshow("mask", final_mask * 0.5 + 0.5)
                final = cv2.cvtColor(final, cv2.COLOR_BGR2RGB)
                cv2.imwrite(
                    str(root /
                        'generated_mask_{epoch}.png'.format(epoch=epoch)),
                    np.uint8(255 * (final * 0.5 + 0.5)))
                cv2.imshow("generated", final * 0.5 + 0.5)
                cv2.waitKey(10)

        utils.write_event(log, step, Dloss=mean_D_loss)
        utils.write_event(log, step, Gloss=mean_G_loss)
        tq.close()

    except KeyboardInterrupt:
        cv2.destroyAllWindows()
        tq.close()
        print('Ctrl+C, saving snapshot')
        save(epoch, netD, D_model_path)
        save(epoch, netG, G_model_path)
        print('done.')
        exit()
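This fragment calls a save(epoch, net, path) helper that isn't shown. Judging by how checkpoints are written elsewhere on this page, it plausibly looks like:

import torch

def save(epoch, net, path):
    # hypothetical helper matching the save(epoch, netD, D_model_path) calls
    torch.save({'model': net.state_dict(), 'epoch': epoch}, str(path))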
Example #7
def train(args, model: nn.Module, criterion, *, params,
          train_loader, valid_loader, init_optimizer, use_cuda,
          n_epochs=None, patience=2, max_lr_changes=2, writer=SummaryWriter()) -> bool:
    # note: a default like writer=SummaryWriter() is evaluated once, at definition time
    lr = args.lr
    n_epochs = n_epochs or args.n_epochs
    params = list(params)
    optimizer = init_optimizer(params, lr)
    # scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, factor=0.5, patience=5)

    run_root = Path(args.run_root)
    model_path = run_root / 'model.pt'
    best_model_path = run_root / 'best-model.pt'
    if model_path.exists():
        state = load_model(model, model_path)
        epoch = state['epoch']
        step = state['step']
        best_valid_loss = state['best_valid_loss']
    else:
        epoch = 1
        step = 0
        best_valid_loss = float('inf')
    lr_changes = 0
    total_step = len(train_loader)

    save = lambda ep: torch.save({
        'model': model.state_dict(),
        'epoch': ep,
        'step': step,
        'best_valid_loss': best_valid_loss
    }, str(model_path))

    report_each = 10
    log = run_root.joinpath('train.log').open('at', encoding='utf8')
    valid_losses = []
    lr_reset_epoch = epoch
    for epoch in range(epoch, n_epochs + 1):
        model.train()
        tq = tqdm.tqdm(total=(args.epoch_size or
                              len(train_loader) * args.batch_size))
        tq.set_description(f'Epoch {epoch}, lr {lr}')
        losses = []
        tl = train_loader
        if args.epoch_size:
            tl = islice(tl, args.epoch_size // args.batch_size)
        try:
            mean_loss = 0
            for i, (inputs, targets) in enumerate(tl):
                # lr = adjust_learning_rate(optimizer, args.lr, step/total_step, n_epochs)
                if use_cuda:
                    inputs, targets = inputs.cuda(), targets.cuda()
                # outputs = model(inputs)
                # loss = _reduce_loss(criterion(outputs, targets))

                # inline mixup: blend random pairs of inputs and mix their losses
                lam = np.random.beta(0.2, 0.2)
                idx = torch.randperm(inputs.size(0))
                input_a, input_b = inputs, inputs[idx]
                target_a, target_b = targets, targets[idx]

                mixed_input = lam * input_a + (1 - lam) * input_b

                output = model(mixed_input)

                loss = (lam * _reduce_loss(criterion(output, target_a)) +
                        (1 - lam) * _reduce_loss(criterion(output, target_b)))


                batch_size = inputs.size(0)
                
                (batch_size * loss).backward()
                
                # loss.backward()

                if (i + 1) % args.step == 0:
                    optimizer.step()
                    optimizer.zero_grad()
                    step += 1
                tq.update(batch_size)
                losses.append(loss.item())
                mean_loss = np.mean(losses[-report_each:])
                tq.set_postfix(loss=f'{mean_loss:.3f}')
                if i and i % report_each == 0:
                    write_event(log, step, loss=mean_loss)
                    writer.add_scalar('data/train_loss', mean_loss, step)
            write_event(log, step, loss=mean_loss)
            tq.close()
            save(epoch + 1)
            valid_metrics = validation(model, criterion, valid_loader, use_cuda)
            write_event(log, step, **valid_metrics)
            valid_loss = valid_metrics['valid_loss']
            valid_losses.append(valid_loss)
            if valid_loss < best_valid_loss:
                best_valid_loss = valid_loss
                shutil.copy(str(model_path), str(best_model_path))
            
            elif (patience and epoch - lr_reset_epoch > patience and
                  min(valid_losses[-patience:]) > best_valid_loss):
                # "patience" epochs without improvement
                lr_changes += 1
                if lr_changes > max_lr_changes:
                    break
                lr /= 5
                print(f'lr updated to {lr}')
                lr_reset_epoch = epoch
                optimizer = init_optimizer(params, lr)
            
            # scheduler.step(valid_loss)
            # lr = optimizer.param_groups[0]['lr']
            writer.add_scalar('data/valid_loss', valid_loss, step)
            writer.add_scalar('data/lr', lr, step)
            metrics = {k: v for k, v in valid_metrics.items()
                       if k != 'valid_loss'}
            writer.add_scalars('data/metrics', metrics, step)
            
        except KeyboardInterrupt:
            tq.close()
            print('Ctrl+C, saving snapshot')
            save(epoch)
            print('done.')
            return False
    return True
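This example inlines mixup (drawing lam from Beta(0.2, 0.2) and blending inputs and loss terms) and relies on an unshown _reduce_loss helper, presumably something like:

def _reduce_loss(loss):
    # hypothetical: collapse a per-sample loss tensor to a batch-mean scalar
    return loss.sum() / loss.shape[0]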
Example #8
def train(args):
    relu_targets = ['relu1_1', 'relu2_1', 'relu3_1', 'relu4_1', 'relu5_1']
    print("Training decoder for relu_target:", args.relu_target)
    relu_target_id = relu_targets.index(args.relu_target) + 1

    if relu_target_id == 1:
        vgg = load_lua(args.vgg1)
        encoder = Encoder1(vgg)
        decoder = Decoder1()
        epochs = args.d1_epochs
        batch_size = args.d1_batch_size
    elif relu_target_id == 2:
        vgg = load_lua(args.vgg2)
        encoder = Encoder2(vgg)
        decoder = Decoder2()
        epochs = args.d2_epochs
        batch_size = args.d2_batch_size
    elif relu_target_id == 3:
        vgg = load_lua(args.vgg3)
        encoder = Encoder3(vgg)
        decoder = Decoder3()
        epochs = args.d3_epochs
        batch_size = args.d3_batch_size
    elif relu_target_id == 4:
        vgg = load_lua(args.vgg4)
        encoder = Encoder4(vgg)
        decoder = Decoder4()
        epochs = args.d4_epochs
        batch_size = args.d4_batch_size
    elif relu_target_id == 5:
        vgg = load_lua(args.vgg5)
        encoder = Encoder5(vgg)
        decoder = Decoder5()
        epochs = args.d5_epochs
        batch_size = args.d5_batch_size

    train_dataset = TrainDataset(
        os.path.join(args.dataset_dir, args.train_img_dir), args.img_size)
    val_dataset = TrainDataset(
        os.path.join(args.dataset_dir, args.val_img_dir), args.img_size)

    train_dataloader = DataLoader(train_dataset,
                                  batch_size=batch_size,
                                  shuffle=True,
                                  num_workers=batch_size,
                                  pin_memory=torch.cuda.is_available())
    val_dataloader = DataLoader(val_dataset,
                                batch_size=batch_size // 2,
                                shuffle=False,
                                num_workers=batch_size // 2,
                                pin_memory=torch.cuda.is_available())

    if args.cuda and torch.cuda.is_available():
        encoder = encoder.cuda()
        decoder = decoder.cuda()

    optimizer = Adam(decoder.parameters(), lr=args.lr)
    loss_fn = MSELoss()

    run_id = args.run_id
    model_path = Path('model_{relu_target}_{run_id}.pt'.format(
        relu_target=relu_targets[relu_target_id - 1], run_id=run_id))
    log_file = open('train_{relu_target}_{run_id}.log'.format(
        relu_target=relu_targets[relu_target_id - 1], run_id=run_id),
                    'at',
                    encoding='utf8')

    step = 0
    valid_losses = []

    for epoch in range(epochs):
        decoder.train()
        random.seed()
        tq = tqdm(total=len(train_dataloader) * batch_size)
        tq.set_description(
            'Run Id {}, Relu Target {} Epoch {} of {}, lr {}'.format(
                run_id, relu_targets[relu_target_id - 1], epoch, epochs,
                args.lr))
        losses = []
        try:
            mean_loss = 0.
            for i, input_imgs in enumerate(train_dataloader):
                if args.cuda and torch.cuda.is_available():
                    input_imgs = input_imgs.cuda()
                encoded = encoder(input_imgs)
                decoded = decoder(encoded)
                encoded_decoded = encoder(decoded)
                pixel_loss = args.pixel_weight * loss_fn(decoded, input_imgs)
                feature_loss = args.feature_weight * loss_fn(
                    encoded_decoded, encoded)
                loss = pixel_loss + feature_loss

                optimizer.zero_grad()  # clear gradients from the previous batch
                loss.backward()
                optimizer.step()

                step += 1

                tq.update(batch_size)
                losses.append(loss.item())
                mean_loss = np.mean(losses[-args.log_fr:])
                tq.set_postfix(loss="{:.6f}".format(mean_loss))

                if i and (i % args.log_fr) == 0:
                    write_event(log_file, step, loss=mean_loss)
            write_event(log_file, step, loss=mean_loss)
            tq.close()
            save_model(decoder, relu_target_id, epoch, step, model_path)

            valid_loss = validation(args, encoder, decoder, loss_fn,
                                    val_dataloader, batch_size)
            valid_loss_metric = {'valid_loss': valid_loss}
            write_event(log_file, step, **valid_loss_metric)
            valid_losses.append(valid_loss)

        except KeyboardInterrupt:
            tq.close()
            print('Ctrl+C, saving snapshot')
            save_model(decoder, relu_target_id, epoch, step, model_path)
            print('Terminated.')
    print('Done.')
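load_lua here is the legacy torch.utils.serialization.load_lua, which read Lua Torch .t7 checkpoints and was removed in PyTorch 1.0. On newer versions the third-party torchfile package reads the same format, though the Encoder* classes above may expect the exact object layout load_lua produced:

import torchfile  # pip install torchfile

# reads the Lua Torch .t7 format that load_lua used to handle;
# the path below is purely illustrative
vgg_t7 = torchfile.load('vgg_normalised_conv1_1.t7')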
def train(args,
          model: nn.Module,
          criterion,
          *,
          params,
          train_loader,
          valid_loader,
          init_optimizer,
          use_cuda,
          n_epochs=None,
          patience=2,
          max_lr_changes=3) -> bool:
    lr = args.lr
    n_epochs = n_epochs or args.n_epochs
    params = list(params)  # in case params is not a list
    # add params into optimizer
    optimizer = init_optimizer(params, lr)

    #model load/save path
    run_root = Path(args.run_root)

    model_path = run_root / 'model.pt'

    if model_path.exists():
        state = load_model(model, model_path)
        epoch = state['epoch']
        step = state['step']
        best_valid_loss = state['best_valid_loss']
        best_f1 = state['best_f1']
    else:
        epoch = 1
        step = 0
        best_valid_loss = float('inf')
        best_f1 = 0

    lr_changes = 0

    save = lambda ep: torch.save(
        {
            'model': model.state_dict(),
            'epoch': ep,
            'step': step,
            'best_valid_loss': best_valid_loss,
            'best_f1': best_f1
        }, str(model_path))

    save_where = lambda ep, svpath: torch.save(
        {
            'model': model.state_dict(),
            'epoch': ep,
            'step': step,
            'best_valid_loss': best_valid_loss,
            'best_f1': best_f1
        }, str(svpath))

    report_each = 100
    log = run_root.joinpath('train.log').open('at', encoding='utf8')
    valid_losses = []
    valid_f1s = []
    lr_reset_epoch = epoch

    #epoch loop
    for epoch in range(epoch, n_epochs + 1):
        model.train()
        tq = tqdm.tqdm(total=(args.epoch_size or TrainDatasetTriplet.tbatch() *
                              len(train_loader) * args.batch_size))

        if epoch >= 10:
            lr = lr * 0.9
            adjust_learning_rate(optimizer, lr)
            print(f'lr updated to {lr}')

        tq.set_description(f'Epoch {epoch}, lr {lr}')
        losses = []
        tl = train_loader
        if args.epoch_size:
            tl = islice(tl, args.epoch_size // args.batch_size)
        try:
            mean_loss = 0

            for i, (inputs, targets) in enumerate(tl):
                if use_cuda:
                    inputs, targets = inputs.cuda(), targets.cuda()

                feats, outputs = model(inputs)
                outputs = outputs.squeeze()
                feats = feats.squeeze()

                loss1 = softmax_loss(outputs, targets)
                loss2 = TripletLossV1(margin=0.5)(feats, targets)

                loss = 0.5 * loss1 + loss2

                batch_size = inputs.size(0)

                (batch_size * loss).backward()
                if (i + 1) % args.step == 0:
                    optimizer.step()
                    optimizer.zero_grad()
                    step += 1
                tq.update(batch_size)
                losses.append(loss.item())
                mean_loss = np.mean(losses[-report_each:])
                tq.set_postfix(loss=f'{mean_loss:.3f}')

                if i and i % report_each == 0:
                    write_event(log, step, loss=mean_loss)

            write_event(log, step, loss=mean_loss)
            tq.close()
            print('saving')
            save(epoch + 1)
            print('validation')
            valid_metrics = validation(model, criterion, valid_loader,
                                       use_cuda)
            write_event(log, step, **valid_metrics)
            valid_loss = valid_metrics['valid_loss']
            valid_losses.append(valid_loss)

            if valid_loss < best_valid_loss:
                best_valid_loss = valid_loss
                shutil.copy(str(model_path),
                            str(run_root / 'model_loss_best.pt'))

        except KeyboardInterrupt:
            tq.close()
            # print('Ctrl+C, saving snapshot')
            # save(epoch)
            # print('done.')

            return False
    return True
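adjust_learning_rate is another unshown helper; the conventional implementation simply writes the new rate into every parameter group:

def adjust_learning_rate(optimizer, lr):
    # set the same learning rate on every parameter group
    for param_group in optimizer.param_groups:
        param_group['lr'] = lr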
Example #10
def train(args,
          model: nn.Module,
          criterion,
          *,
          train_loader,
          valid_loader,
          validation,
          init_optimizer,
          fold=None,
          save_predictions=None,
          n_epochs=configs.EPOCHS):
    ##          Checkpoint path
    checkpoint_path = configs.CHECKPOINT_PATH
    model_checkpoint = checkpoint_path / 'ternaus_{fold}.pt'.format(fold=fold)
    best_model_checkpoint = checkpoint_path / 'ternaus_best_{fold}.pt'.format(
        fold=fold)

    ##          Log (opened here so write_event below has a handle; the filename is assumed)
    log = checkpoint_path.joinpath('train_{fold}.log'.format(fold=fold)).open(
        'at', encoding='utf8')

    ##          Start training
    if model_checkpoint.exists():
        state = torch.load(str(model_checkpoint))
        epoch = state['epoch']
        step = state['step']
        best_valid_loss = state['best_valid_loss']
        model.load_state_dict(state['model'])
        print("Restored model, epoch {}, step {:,}".format(epoch, step))
    else:
        epoch = 1
        step = 0
        best_valid_loss = float('inf')

    ##          Save checkpoint
    save = lambda ep: torch.save(
        {
            'model': model.state_dict(),
            'epoch': ep,
            'step': step,
            'best_valid_loss': best_valid_loss
        }, str(model_checkpoint))

    ##      Initializer epoch
    report_each = 10
    save_prediction_each = report_each * 20
    valid_losses = []

    ###         Training
    for epoch in range(epoch, n_epochs + 1):
        lr = utils.cyclic_lr(epoch)
        optimizer = init_optimizer(lr)

        model.train()
        random.seed()
        tq = tqdm(total=(len(train_loader) * configs.BATCH_SIZE))
        tq.set_description("Epoch {}, lr: {}".format(epoch, lr))

        losses = []
        # tl = train_loader
        # if args.epoch_size:
        #     tl = islice(tl, args.epoch_size // args.batch_size)

        mean_loss = 0
        for i, (inputs, targets) in enumerate(train_loader):
            ##
            inputs, targets = utils.variable(inputs), utils.variable(targets)
            outputs = model(inputs)
            loss = criterion(outputs, targets)
            ##
            optimizer.zero_grad()
            batch_size = inputs.size(0)
            step += 1
            tq.update(batch_size)
            losses.append(loss.item())
            mean_loss = np.mean(losses[-report_each:])
            tq.set_postfix(loss='{:.5f}'.format(mean_loss))

            (batch_size * loss).backward()
            optimizer.step()

            if i and i % report_each == 0:
                utils.write_event(log, step, loss=mean_loss)
                if save_predictions and i % save_prediction_each == 0:
                    p_i = (i // save_prediction_each) % 5
                    # 'root' is defined at module scope in the original repo
                    save_predictions(root, p_i, inputs, targets, outputs)
        utils.write_event(log, step, loss=mean_loss)
        tq.close()
        save(epoch + 1)
        valid_metrics = validation(model, criterion, valid_loader)
        write_event(log, step, **valid_metrics)
        valid_loss = valid_metrics['valid_loss']
        valid_losses.append(valid_loss)
        if valid_loss < best_valid_loss:
            best_valid_loss = valid_loss
            shutil.copy(str(model_checkpoint), str(best_model_checkpoint))
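utils.cyclic_lr(epoch) recomputes the learning rate from the epoch number, and the optimizer is rebuilt each epoch. A sketch of a typical restart-style schedule; all constants here are assumptions:

def cyclic_lr(epoch, init_lr=1e-4, num_epochs_per_cycle=5,
              cycle_epochs_decay=2, lr_decay_factor=0.5):
    # restart the step-decay schedule at the start of every cycle
    epoch_in_cycle = epoch % num_epochs_per_cycle
    return init_lr * lr_decay_factor ** (epoch_in_cycle // cycle_epochs_decay)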
def train(args,
          model: nn.Module,
          criterion,
          *,
          params,
          train_loader,
          valid_loader,
          init_optimizer,
          use_cuda,
          n_epochs=None,
          patience=2,
          max_lr_changes=3) -> bool:
    lr = args.lr
    n_epochs = n_epochs or args.n_epochs
    params = list(params)
    optimizer = init_optimizer(params, lr)

    run_root = Path(args.run_root)

    model_path = run_root / 'model.pt'

    if model_path.exists():
        state = load_model(model, model_path)
        epoch = state['epoch']
        step = state['step']
        best_valid_loss = state['best_valid_loss']
        best_f2 = state['best_f2']
    else:
        epoch = 1
        step = 0
        best_valid_loss = float('inf')
        best_f2 = 0

    lr_changes = 0

    save = lambda ep: torch.save(
        {
            'model': model.state_dict(),
            'epoch': ep,
            'step': step,
            'best_valid_loss': best_valid_loss,
            'best_f2': best_f2
        }, str(model_path))

    report_each = 100
    log = run_root.joinpath('train.log').open('at', encoding='utf8')
    valid_losses = []
    valid_f2s = []
    lr_reset_epoch = epoch
    for epoch in range(epoch, n_epochs + 1):
        model.train()
        tq = tqdm.tqdm(
            total=(args.epoch_size or len(train_loader) * args.batch_size))
        tq.set_description(f'Epoch {epoch}, lr {lr}')
        losses = []
        tl = train_loader
        if args.epoch_size:
            tl = islice(tl, args.epoch_size // args.batch_size)
        try:
            mean_loss = 0
            for i, (inputs, targets) in enumerate(tl):
                if use_cuda:
                    inputs, targets = inputs.cuda(), targets.cuda()
                inputs, targets_a, targets_b, lam = mixup_data(
                    inputs, targets, 1, use_cuda)
                inputs, targets_a, targets_b = Variable(inputs), Variable(
                    targets_a), Variable(targets_b)
                outputs = model(inputs)
                loss_func = mixup_criterion(targets_a, targets_b, lam)
                loss = loss_func(criterion, outputs)
                loss = _reduce_loss(loss)

                batch_size = inputs.size(0)
                (batch_size * loss).backward()
                if (i + 1) % args.step == 0:
                    optimizer.step()
                    optimizer.zero_grad()
                    step += 1
                tq.update(batch_size)
                losses.append(loss.item())
                mean_loss = np.mean(losses[-report_each:])
                tq.set_postfix(loss=f'{mean_loss:.3f}')
                # if i and i % report_each == 0:
                #     write_event(log, step, loss=mean_loss)
            write_event(log, step, loss=mean_loss)
            tq.close()
            save(epoch + 1)
            valid_metrics = validation(model, criterion, valid_loader,
                                       use_cuda)
            write_event(log, step, **valid_metrics)
            valid_loss = valid_metrics['valid_loss']
            valid_f2 = valid_metrics['valid_f2_th_0.10']
            valid_f2s.append(valid_f2)
            valid_losses.append(valid_loss)

            if valid_loss < best_valid_loss:
                best_valid_loss = valid_loss
                #shutil.copy(str(model_path), str(run_root) + '/model_loss_' + f'{valid_loss:.4f}' + '.pt')

            if valid_f2 > best_f2:
                best_f2 = valid_f2
                shutil.copy(
                    str(model_path),
                    str(run_root) + '/model_f2_' + f'{valid_f2:.4f}' + '.pt')


#             if epoch == 7:
#                 lr = 1e-4
#                 print(f'lr updated to {lr}')
#                 optimizer = init_optimizer(params, lr)
#             if epoch == 8:
#                 lr = 1e-5
#                 optimizer = init_optimizer(params, lr)
#                 print(f'lr updated to {lr}')
        except KeyboardInterrupt:
            tq.close()
            #             print('Ctrl+C, saving snapshot')
            #             save(epoch)
            #             print('done.')
            return False
    return True
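Unlike Example #7, which inlines mixup, this version delegates to mixup_data / mixup_criterion. The call signatures match the reference mixup implementation (mixup-cifar10), which looks roughly like:

import numpy as np
import torch

def mixup_data(x, y, alpha=1.0, use_cuda=True):
    # sample the mixing coefficient and a random re-pairing of the batch
    lam = np.random.beta(alpha, alpha) if alpha > 0 else 1.0
    index = torch.randperm(x.size(0))
    if use_cuda:
        index = index.cuda()
    mixed_x = lam * x + (1 - lam) * x[index]
    return mixed_x, y, y[index], lam

def mixup_criterion(y_a, y_b, lam):
    # convex combination of the losses against both target sets
    return lambda criterion, pred: (lam * criterion(pred, y_a) +
                                    (1 - lam) * criterion(pred, y_b))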
Example #12
def main():
    parser = argparse.ArgumentParser()
    arg = parser.add_argument
    arg('--jaccard-weight', type=float, default=1)
    arg('--root', type=str, default='runs/debug', help='checkpoint root')
    arg('--image-path', type=str, default='data', help='image path')
    arg('--batch-size', type=int, default=2)
    arg('--n-epochs', type=int, default=100)
    arg('--optimizer', type=str, default='Adam', help='Adam or SGD')
    arg('--lr', type=float, default=0.001)
    arg('--workers', type=int, default=10)
    arg('--model',
        type=str,
        default='UNet16',
        choices=[
            'UNet', 'UNet11', 'UNet16', 'LinkNet34', 'FCDenseNet57',
            'FCDenseNet67', 'FCDenseNet103'
        ])
    arg('--model-weight', type=str, default=None)
    arg('--resume-path', type=str, default=None)
    arg('--attribute',
        type=str,
        default='all',
        choices=[
            'pigment_network', 'negative_network', 'streaks',
            'milia_like_cyst', 'globules', 'all'
        ])
    args = parser.parse_args()

    ## folder for checkpoint
    root = Path(args.root)
    root.mkdir(exist_ok=True, parents=True)

    image_path = args.image_path

    #print(args)
    if args.attribute == 'all':
        num_classes = 5
    else:
        num_classes = 1
    args.num_classes = num_classes
    ### save initial parameters
    print('--' * 10)
    print(args)
    print('--' * 10)
    root.joinpath('params.json').write_text(
        json.dumps(vars(args), indent=True, sort_keys=True))

    ## load pretrained model
    if args.model == 'UNet':
        model = UNet(num_classes=num_classes)
    elif args.model == 'UNet11':
        model = UNet11(num_classes=num_classes, pretrained='vgg')
    elif args.model == 'UNet16':
        model = UNet16(num_classes=num_classes, pretrained='vgg')
    elif args.model == 'LinkNet34':
        model = LinkNet34(num_classes=num_classes, pretrained=True)
    elif args.model == 'FCDenseNet103':
        model = FCDenseNet103(num_classes=num_classes)
    else:
        model = UNet(num_classes=num_classes, input_channels=3)

    ## multiple GPUs
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    if torch.cuda.device_count() > 1:
        model = nn.DataParallel(model)
    model.to(device)

    ## load pretrained model
    if args.model_weight is not None:
        state = torch.load(args.model_weight)
        #epoch = state['epoch']
        #step = state['step']
        model.load_state_dict(state['model'])
        print('--' * 10)
        print('Load pretrained model', args.model_weight)
        #print('Restored model, epoch {}, step {:,}'.format(epoch, step))
        print('--' * 10)
        ## replace the last layer
        ## although the model and the pre-trained weights have different sizes
        ## (the last layer differs), pytorch can still load the weights;
        ## the weights for that layer are simply duplicated across outputs,
        ## so the following code is not necessary
        # if args.attribute == 'all':
        #     model = list(model.children())[0]
        #     num_filters = 32
        #     model.final = nn.Conv2d(num_filters, num_classes, kernel_size=1)
        #     print('--' * 10)
        #     print('Load pretrained model and replace the last layer', args.model_weight, num_classes)
        #     print('--' * 10)
        #     if torch.cuda.device_count() > 1:
        #         model = nn.DataParallel(model)
        #     model.to(device)

    ## model summary
    print_model_summay(model)

    ## define loss
    loss_fn = LossBinary(jaccard_weight=args.jaccard_weight)

    ## It enables benchmark mode in cudnn.
    ## benchmark mode is good whenever your input sizes for your network do not vary. This way, cudnn will look for the
    ## optimal set of algorithms for that particular configuration (which takes some time). This usually leads to faster runtime.
    ## But if your input sizes change at each iteration, then cudnn will benchmark every time a new size appears,
    ## possibly leading to worse runtime performance.
    cudnn.benchmark = True

    ## get train_test_id
    train_test_id = get_split()

    ## train vs. val
    print('--' * 10)
    print('num train = {}, num_val = {}'.format(
        (train_test_id['Split'] == 'train').sum(),
        (train_test_id['Split'] != 'train').sum()))
    print('--' * 10)

    train_transform = DualCompose(
        [HorizontalFlip(),
         VerticalFlip(),
         ImageOnly(Normalize())])

    val_transform = DualCompose([ImageOnly(Normalize())])

    ## define data loader
    train_loader = make_loader(train_test_id,
                               image_path,
                               args,
                               train=True,
                               shuffle=True,
                               transform=train_transform)
    valid_loader = make_loader(train_test_id,
                               image_path,
                               args,
                               train=False,
                               shuffle=True,
                               transform=val_transform)

    if True:
        print('--' * 10)
        print('check data')
        train_image, train_mask, train_mask_ind = next(iter(train_loader))
        print('train_image.shape', train_image.shape)
        print('train_mask.shape', train_mask.shape)
        print('train_mask_ind.shape', train_mask_ind.shape)
        print('train_image.min', train_image.min().item())
        print('train_image.max', train_image.max().item())
        print('train_mask.min', train_mask.min().item())
        print('train_mask.max', train_mask.max().item())
        print('train_mask_ind.min', train_mask_ind.min().item())
        print('train_mask_ind.max', train_mask_ind.max().item())
        print('--' * 10)

    valid_fn = validation_binary

    ###########
    ## optimizer
    if args.optimizer == 'Adam':
        optimizer = Adam(model.parameters(), lr=args.lr)
    elif args.optimizer == 'SGD':
        optimizer = SGD(model.parameters(), lr=args.lr, momentum=0.9)

    ## loss
    criterion = loss_fn
    ## change LR
    scheduler = ReduceLROnPlateau(optimizer,
                                  'min',
                                  factor=0.8,
                                  patience=5,
                                  verbose=True)

    ##########
    ## load previous model status
    previous_valid_loss = 10
    model_path = root / 'model.pt'
    if args.resume_path is not None and model_path.exists():
        state = torch.load(str(model_path))
        epoch = state['epoch']
        step = state['step']
        model.load_state_dict(state['model'])
        # the counters restart even though the weights were restored
        epoch = 1
        step = 0
        try:
            previous_valid_loss = state['valid_loss']
        except KeyError:
            previous_valid_loss = 10
        print('--' * 10)
        print('Restored previous model, epoch {}, step {:,}'.format(
            epoch, step))
        print('--' * 10)
    else:
        epoch = 1
        step = 0

    #########
    ## start training
    log = root.joinpath('train.log').open('at', encoding='utf8')
    writer = SummaryWriter()
    meter = AllInOneMeter()
    print('Start training')
    print_model_summay(model)
    previous_valid_jaccard = 0
    for epoch in range(epoch, args.n_epochs + 1):
        model.train()
        random.seed()
        #jaccard = []
        start_time = time.time()
        meter.reset()
        w1 = 1.0
        w2 = 0.5
        w3 = 0.5
        try:
            train_loss = 0
            valid_loss = 0
            # if epoch == 1:
            #     freeze_layer_names = get_freeze_layer_names(part='encoder')
            #     set_freeze_layers(model, freeze_layer_names=freeze_layer_names)
            #     #set_train_layers(model, train_layer_names=['module.final.weight','module.final.bias'])
            #     print_model_summay(model)
            # elif epoch == 5:
            #     w1 = 1.0
            #     w2 = 0.0
            #     w3 = 0.5
            #     freeze_layer_names = get_freeze_layer_names(part='encoder')
            #     set_freeze_layers(model, freeze_layer_names=freeze_layer_names)
            #     # set_train_layers(model, train_layer_names=['module.final.weight','module.final.bias'])
            #     print_model_summay(model)
            #elif epoch == 3:
            #     set_train_layers(model, train_layer_names=['module.dec5.block.0.conv.weight','module.dec5.block.0.conv.bias',
            #                                                'module.dec5.block.1.weight','module.dec5.block.1.bias',
            #                                                'module.dec4.block.0.conv.weight','module.dec4.block.0.conv.bias',
            #                                                'module.dec4.block.1.weight','module.dec4.block.1.bias',
            #                                                'module.dec3.block.0.conv.weight','module.dec3.block.0.conv.bias',
            #                                                'module.dec3.block.1.weight','module.dec3.block.1.bias',
            #                                                'module.dec2.block.0.conv.weight','module.dec2.block.0.conv.bias',
            #                                                'module.dec2.block.1.weight','module.dec2.block.1.bias',
            #                                                'module.dec1.conv.weight','module.dec1.conv.bias',
            #                                                'module.final.weight','module.final.bias'])
            #     print_model_summay(model)
            # elif epoch == 50:
            #     set_freeze_layers(model, freeze_layer_names=None)
            #     print_model_summay(model)
            for i, (train_image, train_mask,
                    train_mask_ind) in enumerate(train_loader):
                # inputs, targets = variable(inputs), variable(targets)

                train_image = train_image.permute(0, 3, 1, 2)
                train_mask = train_mask.permute(0, 3, 1, 2)
                train_image = train_image.to(device)
                train_mask = train_mask.to(device).type(torch.cuda.FloatTensor)
                train_mask_ind = train_mask_ind.to(device).type(
                    torch.cuda.FloatTensor)
                # if args.problem_type == 'binary':
                #     train_mask = train_mask.to(device).type(torch.cuda.FloatTensor)
                # else:
                #     #train_mask = train_mask.to(device).type(torch.cuda.LongTensor)
                #     train_mask = train_mask.to(device).type(torch.cuda.FloatTensor)

                outputs, outputs_mask_ind1, outputs_mask_ind2 = model(
                    train_image)
                #print(outputs.size())
                #print(outputs_mask_ind1.size())
                #print(outputs_mask_ind2.size())
                ### note that the last layer in the model is defined differently
                # if args.problem_type == 'binary':
                #     train_prob = F.sigmoid(outputs)
                #     loss = criterion(outputs, train_mask)
                # else:
                #     #train_prob = outputs
                #     train_prob = F.sigmoid(outputs)
                #     loss = torch.tensor(0).type(train_mask.type())
                #     for feat_inx in range(train_mask.shape[1]):
                #         loss += criterion(outputs, train_mask)
                train_prob = F.sigmoid(outputs)
                train_mask_ind_prob1 = F.sigmoid(outputs_mask_ind1)
                train_mask_ind_prob2 = F.sigmoid(outputs_mask_ind2)
                loss1 = criterion(outputs, train_mask)
                #loss1 = F.binary_cross_entropy_with_logits(outputs, train_mask)
                #loss2 = nn.BCEWithLogitsLoss()(outputs_mask_ind1, train_mask_ind)
                #print(train_mask_ind.size())
                #weight = torch.ones_like(train_mask_ind)
                #weight[:, 0] = weight[:, 0] * 1
                #weight[:, 1] = weight[:, 1] * 14
                #weight[:, 2] = weight[:, 2] * 14
                #weight[:, 3] = weight[:, 3] * 4
                #weight[:, 4] = weight[:, 4] * 4
                #weight = weight * train_mask_ind + 1
                #weight = weight.to(device).type(torch.cuda.FloatTensor)
                loss2 = F.binary_cross_entropy_with_logits(
                    outputs_mask_ind1, train_mask_ind)
                loss3 = F.binary_cross_entropy_with_logits(
                    outputs_mask_ind2, train_mask_ind)
                #loss3 = criterion(outputs_mask_ind2, train_mask_ind)
                loss = loss1 * w1 + loss2 * w2 + loss3 * w3
                #print(loss1.item(), loss2.item(), loss.item())
                optimizer.zero_grad()
                loss.backward()
                optimizer.step()
                step += 1
                #jaccard += [get_jaccard(train_mask, (train_prob > 0).float()).item()]
                meter.add(train_prob, train_mask, train_mask_ind_prob1,
                          train_mask_ind_prob2, train_mask_ind, loss1.item(),
                          loss2.item(), loss3.item(), loss.item())
                # print(train_mask.data.shape)
                # print(train_mask.data.sum(dim=-2).shape)
                # print(train_mask.data.sum(dim=-2).sum(dim=-1).shape)
                # print(train_mask.data.sum(dim=-2).sum(dim=-1).sum(dim=0).shape)
                # intersection = train_mask.data.sum(dim=-2).sum(dim=-1)
                # print(intersection.shape)
                # print(intersection.dtype)
                # print(train_mask.data.shape[0])
                #torch.zeros([2, 4], dtype=torch.float32)
            #########################
            ## at the end of each epoch, evaluate the metrics
            epoch_time = time.time() - start_time
            train_metrics = meter.value()
            train_metrics['epoch_time'] = epoch_time
            train_metrics['image'] = train_image.data
            train_metrics['mask'] = train_mask.data
            train_metrics['prob'] = train_prob.data

            #train_jaccard = np.mean(jaccard)
            #train_auc = str(round(mtr1.value()[0],2))+' '+str(round(mtr2.value()[0],2))+' '+str(round(mtr3.value()[0],2))+' '+str(round(mtr4.value()[0],2))+' '+str(round(mtr5.value()[0],2))
            valid_metrics = valid_fn(model, criterion, valid_loader, device,
                                     num_classes)
            ##############
            ## write events
            write_event(log,
                        step,
                        epoch=epoch,
                        train_metrics=train_metrics,
                        valid_metrics=valid_metrics)
            #save_weights(model, model_path, epoch + 1, step)
            #########################
            ## tensorboard
            write_tensorboard(writer,
                              model,
                              epoch,
                              train_metrics=train_metrics,
                              valid_metrics=valid_metrics)
            #########################
            ## save the best model
            valid_loss = valid_metrics['loss1']
            valid_jaccard = valid_metrics['jaccard']
            if valid_loss < previous_valid_loss:
                save_weights(model, model_path, epoch + 1, step, train_metrics,
                             valid_metrics)
                previous_valid_loss = valid_loss
                print('Save best model by loss')
            if valid_jaccard > previous_valid_jaccard:
                save_weights(model, model_path, epoch + 1, step, train_metrics,
                             valid_metrics)
                previous_valid_jaccard = valid_jaccard
                print('Save best model by jaccard')
            #########################
            ## change learning rate
            scheduler.step(valid_metrics['loss1'])

        except KeyboardInterrupt:
            # print('--' * 10)
            # print('Ctrl+C, saving snapshot')
            # save_weights(model, model_path, epoch, step)
            # print('done.')
            # print('--' * 10)
            writer.close()
            #return
    writer.close()
def train(args, model: nn.Module, criterion, *, params,
          train_loader, valid_loader, init_optimizer, use_cuda,
          n_epochs=None, patience=2, max_lr_changes=3) -> bool:
    lr = args.lr
    n_epochs = n_epochs or args.n_epochs
    params = list(params)  # in case params is not a list
    # add params into optimizer
    optimizer = init_optimizer(params, lr)

    # model load/save path
    run_root = Path(args.run_root)

    model_path = run_root / 'model.pt'

    if model_path.exists():
        print('loading existing weights from model.pt')
        state = load_model(model, model_path)
        epoch = state['epoch']
        step = state['step']
        best_valid_loss = state['best_valid_loss']
        best_f1 = state['best_f1']
    else:
        epoch = 1
        step = 0
        best_valid_loss = 0.0  # tracks the best ROC AUC below, so higher is better
        best_f1 = 0

    lr_changes = 0

    save = lambda ep: torch.save({
        'model': model.state_dict(),
        'epoch': ep,
        'step': step,
        'best_valid_loss': best_valid_loss,
        'best_f1': best_f1
    }, str(model_path))

    save_where = lambda ep, svpath: torch.save({
        'model': model.state_dict(),
        'epoch': ep,
        'step': step,
        'best_valid_loss': best_valid_loss,
        'best_f1': best_f1
    }, str(svpath))

    report_each = 100
    log = run_root.joinpath('train.log').open('at', encoding='utf8')
    valid_losses = []
    valid_f1s = []
    lr_reset_epoch = epoch

    # epoch loop
    for epoch in range(epoch, n_epochs + 1):
        model.train()
        tq = tqdm.tqdm(total=(args.epoch_size or
                              len(train_loader) * args.batch_size))

        if epoch >= 20 and epoch % 2 == 0:
            lr = lr * 0.9
            adjust_learning_rate(optimizer, lr)
            print('lr updated to %0.8f' % lr)

        tq.set_description('Epoch %d, lr %0.8f' % (epoch, lr))
        losses = []
        tl = train_loader
        if args.epoch_size:
            tl = islice(tl, args.epoch_size // args.batch_size)
        try:
            mean_loss = 0

            for i, batch_dat in enumerate(tl):  # enumerate yields (batch index, batch dict)
                featCompPos = batch_dat['feat_comp_pos']
                featCompNeg = batch_dat['feat_comp_neg']
                featRegion = batch_dat['feat_comp_region']
                featLoc = batch_dat['feat_loc']

                if use_cuda:
                    featCompPos, featCompNeg = featCompPos.cuda(), featCompNeg.cuda()
                    featRegion, featLoc = featRegion.cuda(), featLoc.cuda()

                if args.model == 'location_recommend_region_model_v1':
                    model_output_pos = model(feat_comp=featCompPos, feat_K_comp=featRegion)
                    model_output_neg = model(feat_comp=featCompNeg, feat_K_comp=featRegion)
                elif args.model == 'location_recommend_region_model_v0':
                    model_output_pos = model(feat_comp=featCompPos, feat_loc=featLoc)
                    model_output_neg = model(feat_comp=featCompNeg, feat_loc=featLoc)
                else:
                    model_output_pos = model(feat_comp=featCompPos, feat_K_comp=featRegion, feat_loc=featLoc)
                    model_output_neg = model(feat_comp=featCompNeg, feat_K_comp=featRegion, feat_loc=featLoc)


                nP, nN = model_output_pos['outputs'].shape[0], model_output_neg['outputs'].shape[0]
                target_pos = torch.ones((nP, 1), dtype=torch.long)
                target_neg = torch.zeros((nN, 1), dtype=torch.long)

                if use_cuda:
                    target_pos = target_pos.cuda()
                    target_neg = target_neg.cuda()

                lossP = softmax_loss(model_output_pos['outputs'], target_pos)
                lossN = softmax_loss(model_output_neg['outputs'], target_neg)
                loss = lossP + 0.8 * lossN  # down-weight the negative-pair loss

                if args.model in ['location_recommend_region_model_v4', 'location_recommend_region_model_v5']:
                    lW = 0.1
                    # positive and negative pairs share the same location features
                    loss2 = l2_loss(model_output_pos['feat_loc_pred'], featLoc)
                    loss = (1 - lW) * loss + lW * loss2

                lossType = 'softmax'

                batch_size = nP + nN

                # accumulate gradients, stepping the optimizer once every
                # args.step batches; the loss is scaled by batch size first
                (batch_size * loss).backward()
                if (i + 1) % args.step == 0:
                    optimizer.step()
                    optimizer.zero_grad()
                    step += 1
                tq.update(args.batch_size)
                losses.append(loss.item())
                mean_loss = np.mean(losses[-report_each:])
                tq.set_postfix(loss=f'{mean_loss:.3f}')

                if i and i % report_each == 0:
                    write_event(log, step, loss=mean_loss)

            write_event(log, step, loss=mean_loss)
            tq.close()
            print('saving')
            save(epoch + 1)
            print('validation')
            valid_metrics = validation(args, model, criterion, valid_loader, use_cuda, lossType=lossType)
            write_event(log, step, **valid_metrics)
            valid_loss = valid_metrics['valid_loss']
            valid_top1 = valid_metrics['valid_top1']
            valid_roc = valid_metrics['auc']
            valid_losses.append(valid_loss)

            # note: despite its name, best_valid_loss tracks the best ROC AUC
            # here, so the comparison direction is flipped (higher is better)
            valid_loss = valid_roc
            if valid_loss > best_valid_loss:
                best_valid_loss = valid_loss
                shutil.copy(str(model_path), str(run_root) + '/model_loss_best.pt')

        except KeyboardInterrupt:
            tq.close()
            return False
    return True
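
The train() above steps the optimizer only once every args.step batches, accumulating gradients in between (it also scales the loss by batch size before backprop). A stripped-down sketch of that accumulation pattern with illustrative names; dividing by accum_steps keeps the effective gradient an average, a common variant of the same idea:

import torch
import torch.nn.functional as F

model = torch.nn.Linear(8, 2)
optimizer = torch.optim.SGD(model.parameters(), lr=0.01)
accum_steps = 4  # plays the role of args.step above

optimizer.zero_grad()
for i in range(16):
    inputs = torch.randn(4, 8)
    targets = torch.randint(0, 2, (4,))
    loss = F.cross_entropy(model(inputs), targets)
    (loss / accum_steps).backward()   # gradients add up across mini-batches
    if (i + 1) % accum_steps == 0:
        optimizer.step()              # apply the accumulated gradient
        optimizer.zero_grad()

This trades memory for batch size: four mini-batches contribute to each optimizer step, so the effective batch is four times larger than what fits on the device.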
Example #14
0
                        result = cv2.cvtColor(
                            cv2.hconcat(
                                (np.uint8(255 * (color * 0.5 + 0.5)),
                                 np.uint8(255 * (pred_color * 0.5 + 0.5)))),
                            cv2.COLOR_BGR2RGB)
                        cv2.imwrite(
                            str(root / 'validation_{counter}.png'.format(
                                counter=counter)), result)
                        counter += 1

                ## Save both models
                best_mean_rec_loss = mean_rec_loss
                save(epoch, netD, D_model_path, best_mean_rec_loss)
                save(epoch, netG, G_model_path, best_mean_rec_loss)
                print("Finding better model in terms of validation loss: {}".
                      format(best_mean_rec_loss))

        utils.write_event(log, step, Rec_error=mean_recover_loss)
        utils.write_event(log, step, Dloss=mean_D_loss)
        utils.write_event(log, step, Gloss=mean_G_loss)
        tq.close()
    except KeyboardInterrupt:
        cv2.destroyAllWindows()
        tq.close()
        print('Ctrl+C, exiting')
        exit()
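
Example #14 dumps side-by-side validation images with OpenCV. A self-contained sketch of the same pattern, assuming the inputs are float images in [-1, 1] (implied by the 0.5 + 0.5 rescaling above); save_validation_pair is a hypothetical helper name:

import cv2
import numpy as np

def save_validation_pair(color, pred_color, out_path):
    # rescale [-1, 1] floats to [0, 255] uint8, concatenate the input and
    # the prediction horizontally, then convert BGR -> RGB before writing
    pair = cv2.hconcat((np.uint8(255 * (color * 0.5 + 0.5)),
                        np.uint8(255 * (pred_color * 0.5 + 0.5))))
    cv2.imwrite(str(out_path), cv2.cvtColor(pair, cv2.COLOR_BGR2RGB))

save_validation_pair(np.random.uniform(-1, 1, (64, 64, 3)),
                     np.random.uniform(-1, 1, (64, 64, 3)),
                     'validation_0.png')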
Example #15
0
def train(args, model: nn.Module, criterion, *, params,
          train_loader, valid_loader, init_optimizer, use_cuda,
          n_epochs=None, patience=2, max_lr_changes=2, finetuning=False) -> bool:
    
    lr = args.lr
    n_epochs = n_epochs or args.n_epochs
    params = list(params)
    optimizer = init_optimizer(params, lr)

    run_root = Path(args.run_root)
    model_path = run_root / 'model.pt'
    best_model_path = run_root / 'best-model.pt'
    
    # load priority: fine-tune from the best checkpoint if requested,
    # otherwise resume from model.pt, otherwise start fresh
    if best_model_path.exists() and finetuning:
        state = load_model(model, best_model_path)
        epoch = 1   # restart the epoch/step counters when fine-tuning
        step = 0
        best_valid_loss = state['best_valid_loss']
    elif model_path.exists():
        state = load_model(model, model_path)
        epoch = state['epoch']
        step = state['step']
        best_valid_loss = state['best_valid_loss']
    else:
        epoch = 1
        step = 0
        best_valid_loss = float('inf')
        
    lr_changes = 0

    save = lambda ep: torch.save({
        'model': model.state_dict(),
        'epoch': ep,
        'step': step,
        'best_valid_loss': best_valid_loss
    }, str(model_path))

    report_each = 10
    log = run_root.joinpath('train.log').open('at', encoding='utf8')
    valid_losses = []
    lr_reset_epoch = epoch
    for epoch in range(epoch, n_epochs + 1):
        model.train()
        tq = tqdm.tqdm(total=(args.epoch_size or
                              len(train_loader) * args.batch_size))
        tq.set_description(f'Epoch {epoch}, lr {lr}')
        losses = []
        tl = train_loader
        if args.epoch_size:
            tl = islice(tl, args.epoch_size // args.batch_size)
            
        try:
            mean_loss = 0
            for i, (inputs, targets) in enumerate(tl):        
                if use_cuda:
                    inputs, targets = inputs.cuda(), targets.cuda()
                outputs = model(inputs)
                loss = criterion(outputs, targets)
                batch_size = inputs.size(0)
                (batch_size * loss).backward()
                if (i + 1) % args.step == 0:
                    optimizer.step()
                    optimizer.zero_grad()
                    step += 1
                tq.update(batch_size)
                losses.append(loss.item())
                mean_loss = np.mean(losses[-report_each:])
                tq.set_postfix(loss=f'{mean_loss:.3f}')
                if i and i % report_each == 0:
                    write_event(log, step, loss=mean_loss)
            write_event(log, step, loss=mean_loss)
            tq.close()
            save(epoch + 1)
            valid_metrics = validation(model, criterion, valid_loader, use_cuda)
            
            write_event(log, step, **valid_metrics)
            
            valid_loss = valid_metrics['valid_loss']
            valid_losses.append(valid_loss)
            
            if valid_loss < best_valid_loss:
                best_valid_loss = valid_loss
                shutil.copy(str(model_path), str(best_model_path))
            elif (patience and epoch - lr_reset_epoch > patience and
                  min(valid_losses[-patience:]) > best_valid_loss):
                # "patience" epochs without improvement
                lr_changes += 1
                if lr_changes > max_lr_changes:
                    break
                lr /= 5
                print(f'lr updated to {lr}')
                lr_reset_epoch = epoch
                optimizer = init_optimizer(params, lr)
        except KeyboardInterrupt:
            tq.close()
            print('Ctrl+C, saving snapshot')
            save(epoch)
            print('done.')
            return False
    return True
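
Both train() variants call load_model() without showing it. A plausible minimal implementation, assuming the checkpoints are the dicts written by save above (a 'model' key holding the state dict plus the bookkeeping keys that the callers read back); this is an assumption, not the original helper:

import torch

def load_model(model, path):
    # hypothetical helper: restore the weights in place and return the
    # full checkpoint dict so callers can read 'epoch', 'step', etc.
    state = torch.load(str(path), map_location='cpu')
    model.load_state_dict(state['model'])
    return state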
Example #16
0
                                         np.uint8(255 * (mask * 0.5 + 0.5)))),
                            cv2.COLOR_BGR2RGB)
                        cv2.imwrite(
                            str(root / 'validation_{counter}.png'.format(
                                counter=counter)), result)
                        counter += 1

                ## Save both models
                best_mean_dice_coeffs = mean_dice_coeffs
                save(epoch, netD, D_model_path, best_mean_dice_coeffs)
                save(epoch, netG, G_model_path, best_mean_dice_coeffs)
                print("Finding better model in terms of validation loss: {}".
                      format(best_mean_dice_coeffs))

        else:
            tq.set_postfix(loss=' D={:.5f}, G={:.5f}'.format(
                np.mean(D_losses), np.mean(G_losses)))

        utils.write_event(log, step, Dice_coeff=mean_dice_coeffs)
        utils.write_event(log, step, Dloss=mean_D_loss)
        utils.write_event(log, step, Gloss=mean_G_loss)
        tq.close()

    except KeyboardInterrupt:
        cv2.destroyAllWindows()
        tq.close()
        print('Ctrl+C, exiting')
        exit()
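
Example #16 reports a mean_dice_coeffs value for segmentation quality. The metric's definition is not shown, so for reference here is one common soft Dice coefficient over a batch of binary masks:

import torch

def dice_coeff(pred, target, eps=1e-7):
    # soft Dice per sample: 2*|A & B| / (|A| + |B|), averaged over the batch;
    # eps guards against division by zero when both masks are empty
    pred = pred.float().reshape(pred.size(0), -1)
    target = target.float().reshape(target.size(0), -1)
    inter = (pred * target).sum(dim=1)
    union = pred.sum(dim=1) + target.sum(dim=1)
    return ((2 * inter + eps) / (union + eps)).mean()

print(dice_coeff(torch.ones(2, 1, 4, 4), torch.ones(2, 1, 4, 4)))  # ~1.0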