Example #1
def train(images_tr, labels_tr, images_val, labels_val, model, model_cat,
          callbacks, train_config, loggers):

    # Create batch generators
    image_gen_tr = get_generator(images_tr, **train_config.daug.daug_params_tr)
    batch_gen_tr = generate_batches(
        image_gen_tr,
        images_tr,
        labels_tr,
        train_config.train.batch_size.gen_tr,
        aug_per_im=train_config.daug.aug_per_img_tr,
        shuffle=True,
        seed=train_config.seeds.batch_shuffle,
        n_inv_layers=train_config.optimizer.n_inv_layers)
    image_gen_val = get_generator(images_val,
                                  **train_config.daug.daug_params_val)
    batch_gen_val = generate_batches(
        image_gen_val,
        images_val,
        labels_val,
        train_config.train.batch_size.gen_val,
        aug_per_im=train_config.daug.aug_per_img_val,
        shuffle=False,
        n_inv_layers=train_config.optimizer.n_inv_layers)
    if FLAGS.no_val:
        batch_gen_val = None

    # Train model
    if FLAGS.no_fit_generator:

        metrics_names_val = [
            'val_{}'.format(metric_name) for metric_name in model.metrics_names
        ]
        no_mean_metrics_progbar = True
        #         no_mean_metrics_progbar = False

        for callback in callbacks.values():
            callback.set_model(model)
            callback.on_train_begin()

        for epoch in range(train_config.train.epochs):

            print('Epoch {}/{}'.format(epoch + 1, train_config.train.epochs))

            # Progress bar
            progbar = Progbar(target=train_config.train.batches_per_epoch_tr)

            for callback in callbacks.values():
                callback.on_epoch_begin(epoch)

            for batch_idx in range(train_config.train.batches_per_epoch_tr):

                for callback in callbacks.values():
                    callback.on_batch_begin(batch_idx)

                # Train
                batch = next(batch_gen_tr)
                debug = False

                # Log
                if loggers:
                    for logger in loggers:
                        logger.get_activations()

                # debug
                if debug:
                    preds = model.predict_on_batch(batch[0])
                    metrics = model.test_on_batch(batch[0], batch[1])
                    metrics_daug = metrics[model.metrics_names.index(
                        'daug_inv5_loss')]
                    metrics_class = metrics[model.metrics_names.index(
                        'class_inv5_loss')]
                    daug_true = batch[1][1]
                    daug_true_rel = daug_true[:, :, 0]
                    daug_true_all = daug_true[:, :, 1]
                    class_true = batch[1][2]
                    class_true_rel = class_true[:, :, 0]
                    class_true_all = class_true[:, :, 1]
                    pred_daug = preds[model.output_names.index(
                        'daug_inv5')][:, :, 0]
                    pred_class = preds[model.output_names.index(
                        'class_inv5')][:, :, 0]

                    num_daug = np.sum(daug_true_rel * pred_daug) / \
                               np.sum(daug_true_rel)
                    den_daug = np.sum(daug_true_all * pred_daug) / \
                               np.sum(daug_true_all)
                    loss_daug = num_daug / den_daug
                    num_class = np.sum(class_true_rel * pred_class) / \
                                np.sum(class_true_rel)
                    den_class = np.sum(class_true_all * pred_class) / \
                                np.sum(class_true_all)
                    loss_class = num_class / den_class
                # debug
                metrics = model.train_on_batch(batch[0], batch[1])
                if model_cat:
                    output_inv = model.predict_on_batch(batch[0])[0]
                    metrics_cat = model_cat.train_on_batch(
                        output_inv, batch[1][0])
                    metrics_names_cat = model_cat.metrics_names[:]
                else:
                    metrics_cat = []
                    metrics_names_cat = []

                # Progress bar
                if batch_idx + 1 < progbar.target:
                    metrics_progbar = sel_metrics(
                        model.metrics_names,
                        metrics,
                        no_mean_metrics_progbar,
                        metrics_cat=metrics_names_cat)
                    metrics_progbar.extend(zip(metrics_names_cat, metrics_cat))
                    progbar.update(current=batch_idx + 1,
                                   values=metrics_progbar)

                # Log
                if loggers:
                    metrics_log = sel_metrics(model.metrics_names,
                                              metrics,
                                              no_mean=False,
                                              metrics_cat=metrics_names_cat)
                    metrics_log.extend(zip(metrics_names_cat, metrics_cat))
                    for logger in loggers:
                        logger.log(metrics_log)

                for callback in callbacks.values():
                    callback.on_batch_end(batch_idx)

            # Validation (skipped when FLAGS.no_val disabled the generator)
            metrics_val = np.zeros(len(metrics))
            if batch_gen_val is not None:
                for batch_idx in range(
                        train_config.train.batches_per_epoch_val):
                    batch = next(batch_gen_val)
                    metrics_val_batch = model.test_on_batch(batch[0], batch[1])
                    for idx, metric in enumerate(metrics_val_batch):
                        metrics_val[idx] += metric
                metrics_val /= train_config.train.batches_per_epoch_val
            metrics_val = metrics_val.tolist()

            # Progress bar
            metrics_progbar = sel_metrics(
                model.metrics_names + metrics_names_val,
                metrics + metrics_val,
                no_mean_metrics_progbar,
                no_val_daug=train_config.daug.aug_per_img_val == 1)
            progbar.add(1, values=metrics_progbar)

            # Tensorboard
            metrics_names_tensorboard = list(progbar.sum_values.keys())
            metrics_tensorboard = [
                metric[0] / float(metric[1])
                for metric in progbar.sum_values.values()
            ]
            for metric_name, metric in zip(
                    model.metrics_names + metrics_names_val,
                    metrics + metrics_val):
                if metric_name not in metrics_names_tensorboard:
                    metrics_names_tensorboard.append(metric_name)
                    metrics_tensorboard.append(metric)
            metrics_tensorboard = sel_metrics(
                metrics_names_tensorboard,
                metrics_tensorboard,
                no_mean=False,
                no_val_daug=train_config.daug.aug_per_img_val > 1,
                metrics_cat=[])
            metrics_tensorboard = [
                list(item) for item in zip(*metrics_tensorboard)
            ]
            write_tensorboard(callbacks['tensorboard'], metrics_tensorboard[0],
                              metrics_tensorboard[1], epoch)

            for callback in callbacks.values():
                callback.on_epoch_end(epoch)

        history = None
    else:
        history = model.fit_generator(
            generator=batch_gen_tr,
            steps_per_epoch=train_config.train.batches_per_epoch_tr,
            epochs=train_config.train.epochs,
            validation_data=batch_gen_val,
            validation_steps=train_config.train.batches_per_epoch_val,
            initial_epoch=train_config.train.initial_epoch,
            max_queue_size=train_config.data.queue_size,
            callbacks=list(callbacks.values()))

    if loggers:
        for logger in loggers:
            logger.close()

    return history, model
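
The train() function above reads every hyperparameter from a deeply nested
train_config object. As a reference, here is a minimal sketch of the structure
it expects, built with types.SimpleNamespace; the field names mirror the
attribute accesses inside train(), but the values are illustrative
placeholders, not the original defaults.

from types import SimpleNamespace

# Hypothetical config mirroring the attribute accesses in train();
# the values below are placeholders, not the original defaults.
train_config = SimpleNamespace(
    train=SimpleNamespace(
        batch_size=SimpleNamespace(gen_tr=32, gen_val=32),
        epochs=10,
        batches_per_epoch_tr=100,
        batches_per_epoch_val=20,
        initial_epoch=0),
    daug=SimpleNamespace(
        daug_params_tr={}, daug_params_val={},
        aug_per_img_tr=1, aug_per_img_val=1),
    seeds=SimpleNamespace(batch_shuffle=42),
    optimizer=SimpleNamespace(n_inv_layers=5),
    data=SimpleNamespace(queue_size=10))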
Example #2
def main():
    parser = argparse.ArgumentParser()
    arg = parser.add_argument
    arg('--jaccard-weight', type=float, default=1)
    arg('--root', type=str, default='runs/debug', help='checkpoint root')
    arg('--image-path', type=str, default='data', help='image path')
    arg('--batch-size', type=int, default=2)
    arg('--n-epochs', type=int, default=100)
    arg('--optimizer', type=str, default='Adam', choices=['Adam', 'SGD'],
        help='Adam or SGD')
    arg('--lr', type=float, default=0.001)
    arg('--workers', type=int, default=10)
    arg('--model',
        type=str,
        default='UNet16',
        choices=[
            'UNet', 'UNet11', 'UNet16', 'LinkNet34', 'FCDenseNet57',
            'FCDenseNet67', 'FCDenseNet103'
        ])
    arg('--model-weight', type=str, default=None)
    arg('--resume-path', type=str, default=None)
    arg('--attribute',
        type=str,
        default='all',
        choices=[
            'pigment_network', 'negative_network', 'streaks',
            'milia_like_cyst', 'globules', 'all'
        ])
    args = parser.parse_args()

    ## folder for checkpoint
    root = Path(args.root)
    root.mkdir(exist_ok=True, parents=True)

    image_path = args.image_path

    #print(args)
    if args.attribute == 'all':
        num_classes = 5
    else:
        num_classes = 1
    args.num_classes = num_classes
    ### save initial parameters
    print('--' * 10)
    print(args)
    print('--' * 10)
    root.joinpath('params.json').write_text(
        json.dumps(vars(args), indent=True, sort_keys=True))

    ## load pretrained model
    if args.model == 'UNet':
        model = UNet(num_classes=num_classes)
    elif args.model == 'UNet11':
        model = UNet11(num_classes=num_classes, pretrained='vgg')
    elif args.model == 'UNet16':
        model = UNet16(num_classes=num_classes, pretrained='vgg')
    elif args.model == 'LinkNet34':
        model = LinkNet34(num_classes=num_classes, pretrained=True)
    elif args.model == 'FCDenseNet103':
        model = FCDenseNet103(num_classes=num_classes)
    else:
        model = UNet(num_classes=num_classes, input_channels=3)

    ## multiple GPUs
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    if torch.cuda.device_count() > 1:
        model = nn.DataParallel(model)
    model.to(device)

    ## load pretrained model
    if args.model_weight is not None:
        state = torch.load(args.model_weight)
        #epoch = state['epoch']
        #step = state['step']
        model.load_state_dict(state['model'])
        print('--' * 10)
        print('Load pretrained model', args.model_weight)
        #print('Restored model, epoch {}, step {:,}'.format(epoch, step))
        print('--' * 10)
        ## Replace the last layer:
        ## although the model and the pre-trained weights have different sizes
        ## (the last layer is different), PyTorch can still load the weights.
        ## I found that the weight for one layer is just duplicated for all
        ## layers; therefore, the following code is not necessary.
        # if args.attribute == 'all':
        #     model = list(model.children())[0]
        #     num_filters = 32
        #     model.final = nn.Conv2d(num_filters, num_classes, kernel_size=1)
        #     print('--' * 10)
        #     print('Load pretrained model and replace the last layer', args.model_weight, num_classes)
        #     print('--' * 10)
        #     if torch.cuda.device_count() > 1:
        #         model = nn.DataParallel(model)
        #     model.to(device)

    ## model summary
    print_model_summay(model)

    ## define loss
    loss_fn = LossBinary(jaccard_weight=args.jaccard_weight)

    ## Enable benchmark mode in cudnn.
    ## Benchmark mode is good whenever the input sizes for the network do not
    ## vary: cudnn will look for the optimal set of algorithms for that
    ## particular configuration (which takes some time), which usually leads
    ## to faster runtime. But if the input size changes at each iteration,
    ## cudnn will benchmark every time a new size appears, possibly leading
    ## to worse runtime performance.
    cudnn.benchmark = True

    ## get train_test_id
    train_test_id = get_split()

    ## train vs. val
    print('--' * 10)
    print('num_train = {}, num_val = {}'.format(
        (train_test_id['Split'] == 'train').sum(),
        (train_test_id['Split'] != 'train').sum()))
    print('--' * 10)

    train_transform = DualCompose(
        [HorizontalFlip(),
         VerticalFlip(),
         ImageOnly(Normalize())])

    val_transform = DualCompose([ImageOnly(Normalize())])

    ## define data loader
    train_loader = make_loader(train_test_id,
                               image_path,
                               args,
                               train=True,
                               shuffle=True,
                               transform=train_transform)
    valid_loader = make_loader(train_test_id,
                               image_path,
                               args,
                               train=False,
                               shuffle=True,
                               transform=val_transform)

    ## sanity-check one batch of data
    print('--' * 10)
    print('check data')
    train_image, train_mask, train_mask_ind = next(iter(train_loader))
    print('train_image.shape', train_image.shape)
    print('train_mask.shape', train_mask.shape)
    print('train_mask_ind.shape', train_mask_ind.shape)
    print('train_image.min', train_image.min().item())
    print('train_image.max', train_image.max().item())
    print('train_mask.min', train_mask.min().item())
    print('train_mask.max', train_mask.max().item())
    print('train_mask_ind.min', train_mask_ind.min().item())
    print('train_mask_ind.max', train_mask_ind.max().item())
    print('--' * 10)

    valid_fn = validation_binary

    ###########
    ## optimizer
    if args.optimizer == 'Adam':
        optimizer = Adam(model.parameters(), lr=args.lr)
    elif args.optimizer == 'SGD':
        optimizer = SGD(model.parameters(), lr=args.lr, momentum=0.9)

    ## loss
    criterion = loss_fn
    ## change LR
    scheduler = ReduceLROnPlateau(optimizer,
                                  'min',
                                  factor=0.8,
                                  patience=5,
                                  verbose=True)

    ##########
    ## load previous model status
    previous_valid_loss = 10
    model_path = root / 'model.pt'
    if args.resume_path is not None and model_path.exists():
        state = torch.load(str(model_path))
        epoch = state['epoch']
        step = state['step']
        model.load_state_dict(state['model'])
        previous_valid_loss = state.get('valid_loss', 10)
        print('--' * 10)
        print('Restored previous model, epoch {}, step {:,}'.format(
            epoch, step))
        print('--' * 10)
    else:
        epoch = 1
        step = 0

    #########
    ## start training
    log = root.joinpath('train.log').open('at', encoding='utf8')
    writer = SummaryWriter()
    meter = AllInOneMeter()
    print('Start training')
    print_model_summay(model)
    previous_valid_jaccard = 0
    for epoch in range(epoch, args.n_epochs + 1):
        model.train()
        random.seed()
        #jaccard = []
        start_time = time.time()
        meter.reset()
        w1 = 1.0
        w2 = 0.5
        w3 = 0.5
        try:
            train_loss = 0
            valid_loss = 0
            # if epoch == 1:
            #     freeze_layer_names = get_freeze_layer_names(part='encoder')
            #     set_freeze_layers(model, freeze_layer_names=freeze_layer_names)
            #     #set_train_layers(model, train_layer_names=['module.final.weight','module.final.bias'])
            #     print_model_summay(model)
            # elif epoch == 5:
            #     w1 = 1.0
            #     w2 = 0.0
            #     w3 = 0.5
            #     freeze_layer_names = get_freeze_layer_names(part='encoder')
            #     set_freeze_layers(model, freeze_layer_names=freeze_layer_names)
            #     # set_train_layers(model, train_layer_names=['module.final.weight','module.final.bias'])
            #     print_model_summay(model)
            #elif epoch == 3:
            #     set_train_layers(model, train_layer_names=['module.dec5.block.0.conv.weight','module.dec5.block.0.conv.bias',
            #                                                'module.dec5.block.1.weight','module.dec5.block.1.bias',
            #                                                'module.dec4.block.0.conv.weight','module.dec4.block.0.conv.bias',
            #                                                'module.dec4.block.1.weight','module.dec4.block.1.bias',
            #                                                'module.dec3.block.0.conv.weight','module.dec3.block.0.conv.bias',
            #                                                'module.dec3.block.1.weight','module.dec3.block.1.bias',
            #                                                'module.dec2.block.0.conv.weight','module.dec2.block.0.conv.bias',
            #                                                'module.dec2.block.1.weight','module.dec2.block.1.bias',
            #                                                'module.dec1.conv.weight','module.dec1.conv.bias',
            #                                                'module.final.weight','module.final.bias'])
            #     print_model_summay(model)
            # elif epoch == 50:
            #     set_freeze_layers(model, freeze_layer_names=None)
            #     print_model_summay(model)
            for i, (train_image, train_mask,
                    train_mask_ind) in enumerate(train_loader):
                # inputs, targets = variable(inputs), variable(targets)

                train_image = train_image.permute(0, 3, 1, 2)
                train_mask = train_mask.permute(0, 3, 1, 2)
                train_image = train_image.to(device)
                train_mask = train_mask.to(device).type(torch.cuda.FloatTensor)
                train_mask_ind = train_mask_ind.to(device).type(
                    torch.cuda.FloatTensor)
                # if args.problem_type == 'binary':
                #     train_mask = train_mask.to(device).type(torch.cuda.FloatTensor)
                # else:
                #     #train_mask = train_mask.to(device).type(torch.cuda.LongTensor)
                #     train_mask = train_mask.to(device).type(torch.cuda.FloatTensor)

                outputs, outputs_mask_ind1, outputs_mask_ind2 = model(
                    train_image)
                #print(outputs.size())
                #print(outputs_mask_ind1.size())
                #print(outputs_mask_ind2.size())
                ### note that the last layer in the model is defined differently
                # if args.problem_type == 'binary':
                #     train_prob = F.sigmoid(outputs)
                #     loss = criterion(outputs, train_mask)
                # else:
                #     #train_prob = outputs
                #     train_prob = F.sigmoid(outputs)
                #     loss = torch.tensor(0).type(train_mask.type())
                #     for feat_inx in range(train_mask.shape[1]):
                #         loss += criterion(outputs, train_mask)
                train_prob = torch.sigmoid(outputs)
                train_mask_ind_prob1 = torch.sigmoid(outputs_mask_ind1)
                train_mask_ind_prob2 = torch.sigmoid(outputs_mask_ind2)
                loss1 = criterion(outputs, train_mask)
                #loss1 = F.binary_cross_entropy_with_logits(outputs, train_mask)
                #loss2 = nn.BCEWithLogitsLoss()(outputs_mask_ind1, train_mask_ind)
                #print(train_mask_ind.size())
                #weight = torch.ones_like(train_mask_ind)
                #weight[:, 0] = weight[:, 0] * 1
                #weight[:, 1] = weight[:, 1] * 14
                #weight[:, 2] = weight[:, 2] * 14
                #weight[:, 3] = weight[:, 3] * 4
                #weight[:, 4] = weight[:, 4] * 4
                #weight = weight * train_mask_ind + 1
                #weight = weight.to(device).type(torch.cuda.FloatTensor)
                loss2 = F.binary_cross_entropy_with_logits(
                    outputs_mask_ind1, train_mask_ind)
                loss3 = F.binary_cross_entropy_with_logits(
                    outputs_mask_ind2, train_mask_ind)
                #loss3 = criterion(outputs_mask_ind2, train_mask_ind)
                loss = loss1 * w1 + loss2 * w2 + loss3 * w3
                #print(loss1.item(), loss2.item(), loss.item())
                optimizer.zero_grad()
                loss.backward()
                optimizer.step()
                step += 1
                #jaccard += [get_jaccard(train_mask, (train_prob > 0).float()).item()]
                meter.add(train_prob, train_mask, train_mask_ind_prob1,
                          train_mask_ind_prob2, train_mask_ind, loss1.item(),
                          loss2.item(), loss3.item(), loss.item())
                # print(train_mask.data.shape)
                # print(train_mask.data.sum(dim=-2).shape)
                # print(train_mask.data.sum(dim=-2).sum(dim=-1).shape)
                # print(train_mask.data.sum(dim=-2).sum(dim=-1).sum(dim=0).shape)
                # intersection = train_mask.data.sum(dim=-2).sum(dim=-1)
                # print(intersection.shape)
                # print(intersection.dtype)
                # print(train_mask.data.shape[0])
                #torch.zeros([2, 4], dtype=torch.float32)
            #########################
            ## at the end of each epoch, evaluate the metrics
            epoch_time = time.time() - start_time
            train_metrics = meter.value()
            train_metrics['epoch_time'] = epoch_time
            train_metrics['image'] = train_image.data
            train_metrics['mask'] = train_mask.data
            train_metrics['prob'] = train_prob.data

            #train_jaccard = np.mean(jaccard)
            #train_auc = str(round(mtr1.value()[0],2))+' '+str(round(mtr2.value()[0],2))+' '+str(round(mtr3.value()[0],2))+' '+str(round(mtr4.value()[0],2))+' '+str(round(mtr5.value()[0],2))
            valid_metrics = valid_fn(model, criterion, valid_loader, device,
                                     num_classes)
            ##############
            ## write events
            write_event(log,
                        step,
                        epoch=epoch,
                        train_metrics=train_metrics,
                        valid_metrics=valid_metrics)
            #save_weights(model, model_path, epoch + 1, step)
            #########################
            ## tensorboard
            write_tensorboard(writer,
                              model,
                              epoch,
                              train_metrics=train_metrics,
                              valid_metrics=valid_metrics)
            #########################
            ## save the best model
            valid_loss = valid_metrics['loss1']
            valid_jaccard = valid_metrics['jaccard']
            if valid_loss < previous_valid_loss:
                save_weights(model, model_path, epoch + 1, step, train_metrics,
                             valid_metrics)
                previous_valid_loss = valid_loss
                print('Save best model by loss')
            if valid_jaccard > previous_valid_jaccard:
                save_weights(model, model_path, epoch + 1, step, train_metrics,
                             valid_metrics)
                previous_valid_jaccard = valid_jaccard
                print('Save best model by jaccard')
            #########################
            ## change learning rate
            scheduler.step(valid_metrics['loss1'])

        except KeyboardInterrupt:
            # On Ctrl+C, close the writer and stop training cleanly;
            # snapshot saving was disabled in the original:
            # save_weights(model, model_path, epoch, step)
            writer.close()
            return
    writer.close()
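
write_event() is not shown in this example. Since the log file is opened in
text-append mode with UTF-8, a plausible minimal sketch is a JSON-lines
logger; this is an assumption, not the original implementation, and the
tensor-valued entries (train_metrics['image'], 'mask', 'prob') would have to
be filtered out before serialization.

import json
from datetime import datetime


def write_event(log, step, **data):
    # One JSON object per line; assumes the remaining values are
    # JSON-serializable scalars.
    data['step'] = step
    data['dt'] = datetime.now().isoformat()
    log.write(json.dumps(data, sort_keys=True))
    log.write('\n')
    log.flush()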
Example #3
g = (tge.TGE().set_graph_def(gdef).set_devices(devices).data_parallel(
    'ps0').compile().get_graph_def())

tf.reset_default_graph()
tf.import_graph_def(g)
graph = tf.get_default_graph()

x = graph.get_tensor_by_name("import/Placeholder:0")
y = graph.get_tensor_by_name("import/Placeholder_1:0")
opt = graph.get_operation_by_name("import/GradientDescent")
init = graph.get_operation_by_name("import/init")
# currently a hack. Later we will add an API for user to get tensor references back
acc = 10 * (graph.get_tensor_by_name("import/Mean/replica_0:0") +
            graph.get_tensor_by_name("import/Mean/replica_1:0")) / 2

write_tensorboard(opt.graph)

workers = ["10.28.1.26:3901", "10.28.1.25:3901"]
server = setup_workers(workers)

sess = tf.Session(server.target,
                  config=tf.ConfigProto(log_device_placement=True))
sess.run(init)


def onehot(x):
    # Infer the number of classes from the labels themselves;
    # renamed from `max` to avoid shadowing the builtin.
    n_classes = x.max() + 1
    return np.eye(n_classes)[x]


(x_train, y_train), (x_test, y_test) = tf.keras.datasets.cifar10.load_data()
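
One caveat with onehot(): tf.keras.datasets.cifar10.load_data() returns labels
of shape (N, 1), so indexing np.eye(n_classes)[x] directly would produce an
(N, 1, 10) array. Flattening first gives the intended (N, 10) matrix:

# Flatten the (N, 1) label arrays before one-hot encoding.
y_train_oh = onehot(y_train.flatten())  # shape (N, 10)
y_test_oh = onehot(y_test.flatten())    # shape (N, 10)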
Example #4
# NOTE: the opening of this chained call was truncated in the source; the
# first line below is reconstructed, as an assumption, from the pattern in
# Example #3.
g = (tge.TGE().set_graph_def(gdef).set_devices(devices)
     # .destructify_names()
     # .compile()
     # .get_result()
     .set_bandwidth(100000, 1000).evaluate({
         node.name: [np.random.randint(0, 1000)] * len(devices)
         for node in gdef.node
     }))
print(g)
toc1 = time.perf_counter()

raise SystemExit  # bail out early; the code below never runs in this snippet

tf.reset_default_graph()
tf.import_graph_def(g)
graph = tf.get_default_graph()
write_tensorboard(graph)

x = graph.get_tensor_by_name("import/Placeholder:0")
y = graph.get_tensor_by_name("import/Placeholder_1:0")
opt = graph.get_operation_by_name("import/GradientDescent")
init = graph.get_operation_by_name("import/init")

data = {
    x: np.random.uniform(size=(64, 224, 224, 3)),
    y: np.random.uniform(size=(64, 1000))
}

config = tf.ConfigProto(allow_soft_placement=True)  # log_device_placement=True
config.gpu_options.allow_growth = True
sess = tf.Session(server.target, config=config)
sess.run(init)
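
The snippet above stops right after initialization (and the raise SystemExit
earlier exits before this point is ever reached). For completeness, a hedged
sketch of what running training steps would look like with the feed data
defined above:

# Illustrative only: run a few optimizer steps on the random feed data.
for step in range(10):
    sess.run(opt, feed_dict=data)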