Example #1
def train():
    with fluid.dygraph.guard():
        epoch_num = train_parameters["num_epochs"]
        net = DenseNet("densenet", layers=121, dropout_prob=train_parameters['dropout_prob'],
                       class_dim=train_parameters['class_dim'])
        optimizer = optimizer_rms_setting(net.parameters())
        file_list = os.path.join(train_parameters['data_dir'], "train.txt")
        train_reader = paddle.batch(reader.custom_image_reader(file_list, train_parameters['data_dir'], 'train'),
                                    batch_size=train_parameters['train_batch_size'],
                                    drop_last=True)
        test_reader = paddle.batch(reader.custom_image_reader(file_list, train_parameters['data_dir'], 'val'),
                                   batch_size=train_parameters['train_batch_size'],
                                   drop_last=True)
        if train_parameters["continue_train"]:
            model, _ = fluid.dygraph.load_dygraph(train_parameters["save_persistable_dir"])
            net.load_dict(model)

        best_acc = 0
        for epoch in range(epoch_num):  # use a distinct loop variable so the epoch count is not shadowed

            for batch_id, data in enumerate(train_reader()):
                dy_x_data = np.array([x[0] for x in data]).astype('float32')
                y_data = np.array([x[1] for x in data]).astype('int')
                y_data = y_data[:, np.newaxis]

                img = fluid.dygraph.to_variable(dy_x_data)
                label = fluid.dygraph.to_variable(y_data)
                label.stop_gradient = True
                t1 = time.time()
                out, acc = net(img, label)
                t2 = time.time()
                forward_time = t2 - t1
                loss = fluid.layers.cross_entropy(out, label)
                avg_loss = fluid.layers.mean(loss)
                # dy_out = avg_loss.numpy()
                t3 = time.time()
                avg_loss.backward()
                t4 = time.time()
                backward_time = t4 - t3
                optimizer.minimize(avg_loss)
                net.clear_gradients()
                # print(forward_time, backward_time)

                dy_param_value = {}
                for param in net.parameters():
                    dy_param_value[param.name] = param.numpy()  # snapshot of parameter values (unused below)

                if batch_id % 40 == 0:
                    logger.info("Loss at epoch {} step {}: {}, acc: {}".format(epoch_num, batch_id, avg_loss.numpy(),
                                                                               acc.numpy()))

            net.eval()
            epoch_acc = eval_net(test_reader, net)
            net.train()
            if epoch_acc > best_acc:
                fluid.dygraph.save_dygraph(net.state_dict(), train_parameters["save_persistable_dir"])
                fluid.dygraph.save_dygraph(optimizer.state_dict(), train_parameters["save_persistable_dir"])
                best_acc = epoch_acc
                logger.info("model saved at epoch {}, best accuracy is {}".format(epoch_num, best_acc))
        logger.info("Final loss: {}".format(avg_loss.numpy()))
Example #2
net = DenseNet(201, args.num_classes, args.dropout).cuda()
#net = ResNet(152, 10).cuda()
parallel_net = DataParallel(net)

def weights_init(m):
    classname = m.__class__.__name__
    if classname.find('Linear') != -1:
        # normal(m.weight.data, 0, 0.02)
        xavier_uniform_(m.weight.data)  # in-place initializer; the non-underscore form is deprecated in current PyTorch
        # xavier_uniform(m.bias.data)


print "net complete"

net.apply(weights_init)
optimizer = optim.Adam(net.parameters(), args.lr)
# optimizer = optim.Adam([
#     {'params': net.features.parameters(), 'lr': args.lr * 0.1},
#     {'params': net.fc.parameters(), 'lr': args.lr}
# ], weight_decay=0.0005)
# optimizer = optim.SGD([
#     {'params': net.features.parameters(), 'lr':args.lr * 0.1},
#     {'params': net.fc.parameters(), 'lr': args.lr}
# ], weight_decay=5e-6, momentum=0.9, nesterov=True)
criterion = FocalLoss(args.num_classes).cuda()
scheduler = StepLR(optimizer, gamma=0.5, step_size=args.decay_step)

train_meter_loss = tnt.meter.AverageValueMeter()
train_meter_accuracy = tnt.meter.ClassErrorMeter(accuracy=True)
train_confusion_meter = tnt.meter.ConfusionMeter(args.num_classes, normalized=True)
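FocalLoss is instantiated above but defined elsewhere. Below is a minimal self-contained sketch of a standard multi-class focal loss (Lin et al., 2017); the gamma=2.0 default and the absence of per-class alpha weighting are assumptions, not taken from the original:

import torch.nn as nn
import torch.nn.functional as F

class FocalLoss(nn.Module):
    # focal loss: scale the per-sample cross-entropy by (1 - p_t)^gamma,
    # down-weighting easy, well-classified examples
    def __init__(self, num_classes, gamma=2.0):
        super(FocalLoss, self).__init__()
        self.num_classes = num_classes
        self.gamma = gamma

    def forward(self, logits, target):
        log_p = F.log_softmax(logits, dim=1)
        log_pt = log_p.gather(1, target.unsqueeze(1)).squeeze(1)  # log-prob of the true class
        pt = log_pt.exp()
        return (-(1.0 - pt) ** self.gamma * log_pt).mean()

The tnt meters created above would typically be fed per batch inside a training loop this excerpt does not reach, e.g. train_meter_loss.add(loss.item()) and train_meter_accuracy.add(output.data, target.data).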
Example #3
    def run_once(self):
        
        log_dir = self.log_dir

        misc.check_manual_seed(self.seed)
        train_pairs, valid_pairs = dataset.prepare_data_VIABLE_2048()
        print(len(train_pairs))
        # --------------------------- Dataloader

        train_augmentors = self.train_augmentors()
        train_dataset = dataset.DatasetSerial(train_pairs[:],
                        shape_augs=iaa.Sequential(train_augmentors[0]),
                        input_augs=iaa.Sequential(train_augmentors[1]))

        infer_augmentors = self.infer_augmentors()
        infer_dataset = dataset.DatasetSerial(valid_pairs[:],
                        shape_augs=iaa.Sequential(infer_augmentors))

        train_loader = data.DataLoader(train_dataset, 
                                num_workers=self.nr_procs_train, 
                                batch_size=self.train_batch_size, 
                                shuffle=True, drop_last=True)

        valid_loader = data.DataLoader(infer_dataset, 
                                num_workers=self.nr_procs_valid, 
                                batch_size=self.infer_batch_size, 
                                shuffle=True, drop_last=False)

        # --------------------------- Training Sequence

        if self.logging:
            misc.check_log_dir(log_dir)

        device = 'cuda'

        # networks
        input_chs = 3    
        net = DenseNet(input_chs, self.nr_classes)
        net = torch.nn.DataParallel(net).to(device)
        # print(net)

        # optimizers
        optimizer = optim.Adam(net.parameters(), lr=self.init_lr)
        scheduler = optim.lr_scheduler.StepLR(optimizer, self.lr_steps)

        # load pre-trained models
        if self.load_network:
            saved_state = torch.load(self.save_net_path)
            net.load_state_dict(saved_state)
        #
        trainer = Engine(lambda engine, batch: self.train_step(net, batch, optimizer, 'cuda'))
        inferer = Engine(lambda engine, batch: self.infer_step(net, batch, 'cuda'))

        train_output = ['loss', 'acc']
        infer_output = ['prob', 'true']
        ##

        if self.logging:
            checkpoint_handler = ModelCheckpoint(log_dir, self.chkpts_prefix, 
                                            save_interval=1, n_saved=120, require_empty=False)
            # adding handlers using `trainer.add_event_handler` method API
            trainer.add_event_handler(event_name=Events.EPOCH_COMPLETED, handler=checkpoint_handler,
                                    to_save={'net': net}) 

        timer = Timer(average=True)
        timer.attach(trainer, start=Events.EPOCH_STARTED, resume=Events.ITERATION_STARTED,
                            pause=Events.ITERATION_COMPLETED, step=Events.ITERATION_COMPLETED)
        timer.attach(inferer, start=Events.EPOCH_STARTED, resume=Events.ITERATION_STARTED,
                            pause=Events.ITERATION_COMPLETED, step=Events.ITERATION_COMPLETED)

        # attach running average metrics computation
        # decay of EMA to 0.95 to match tensorpack default
        RunningAverage(alpha=0.95, output_transform=lambda x: x['loss']).attach(trainer, 'loss')
        RunningAverage(alpha=0.95, output_transform=lambda x: x['acc']).attach(trainer, 'acc')

        # attach progress bar
        pbar = ProgressBar(persist=True)
        pbar.attach(trainer, metric_names=['loss'])
        pbar.attach(inferer)

        # adding handlers using `trainer.on` decorator API
        @trainer.on(Events.EXCEPTION_RAISED)
        def handle_exception(engine, e):
            if isinstance(e, KeyboardInterrupt) and (engine.state.iteration > 1):
                engine.terminate()
                warnings.warn('KeyboardInterrupt caught. Exiting gracefully.')
                checkpoint_handler(engine, {'net_exception': net})
            else:
                raise e

        # writer for tensorboard logging
        if self.logging:
            writer = SummaryWriter(log_dir=log_dir)
            json_log_file = log_dir + '/stats.json'
            with open(json_log_file, 'w') as json_file:
                json.dump({}, json_file) # create empty file

        @trainer.on(Events.EPOCH_STARTED)
        def log_lrs(engine):
            if self.logging:
                lr = float(optimizer.param_groups[0]['lr'])
                writer.add_scalar("lr", lr, engine.state.epoch)
            # advance scheduler clock
            scheduler.step()

        ####
        def update_logs(output, epoch, prefix, color):
            # print values and convert
            max_length = len(max(output.keys(), key=len))
            for metric in output:
                key = colored(prefix + '-' + metric.ljust(max_length), color)
                print('------%s : ' % key, end='')
                print('%0.7f' % output[metric])
            if 'train' in prefix:
                lr = float(optimizer.param_groups[0]['lr'])
                key = colored(prefix + '-' + 'lr'.ljust(max_length), color)
                print('------%s : %0.7f' % (key, lr))

            if not self.logging:
                return

            # create stat dicts
            stat_dict = {}
            for metric in output:
                metric_value = output[metric] 
                stat_dict['%s-%s' % (prefix, metric)] = metric_value

            # json stat log file, update and overwrite
            with open(json_log_file) as json_file:
                json_data = json.load(json_file)

            current_epoch = str(epoch)
            if current_epoch in json_data:
                old_stat_dict = json_data[current_epoch]
                stat_dict.update(old_stat_dict)
            current_epoch_dict = {current_epoch : stat_dict}
            json_data.update(current_epoch_dict)

            with open(json_log_file, 'w') as json_file:
                json.dump(json_data, json_file)

            # log values to tensorboard (global_step must be the int epoch, not the string key)
            for metric in output:
                writer.add_scalar(prefix + '-' + metric, output[metric], epoch)

        @trainer.on(Events.EPOCH_COMPLETED)
        def log_train_running_results(engine):
            """
            running training measurement
            """
            training_ema_output = engine.state.metrics #
            update_logs(training_ema_output, engine.state.epoch, prefix='train-ema', color='green')

        ####
        def get_init_accumulator(output_names):
            return {metric : [] for metric in output_names}

        import cv2  # only needed by the commented-out debug visualization below
        def process_accumulated_output(output):
            def uneven_seq_to_np(seq, batch_size=self.infer_batch_size):
                # stack a list of batch outputs whose last batch may be smaller
                if self.infer_batch_size == 1:
                    return np.squeeze(seq)

                item_count = batch_size * (len(seq) - 1) + len(seq[-1])
                cat_array = np.zeros((item_count,) + seq[0][0].shape, seq[0].dtype)
                for idx in range(len(seq) - 1):
                    cat_array[idx * batch_size:(idx + 1) * batch_size] = seq[idx]
                # index the last (possibly short) batch explicitly, so the loop
                # variable is not needed when len(seq) == 1
                cat_array[(len(seq) - 1) * batch_size:] = seq[-1]
                return cat_array
            #
            prob = uneven_seq_to_np(output['prob'])
            true = uneven_seq_to_np(output['true'])

            # cmap = plt.get_cmap('jet')
            # epi = prob[...,1]
            # epi = (cmap(epi) * 255.0).astype('uint8')
            # cv2.imwrite('sample.png', cv2.cvtColor(epi, cv2.COLOR_RGB2BGR))

            pred = np.argmax(prob, axis=-1)
            true = np.squeeze(true)

            # deal with ignore index
            pred = pred.flatten()
            true = true.flatten()
            pred = pred[true != 0] - 1
            true = true[true != 0] - 1

            acc = np.mean(pred == true)
            inter = (pred * true).sum()
            total = (pred + true).sum()
            dice = 2 * inter / total  # binary dice; assumes labels are in {0, 1} after the shift above
            #
            proc_output = dict(acc=acc, dice=dice)
            return proc_output

        @trainer.on(Events.EPOCH_COMPLETED)
        def infer_valid(engine):
            """
            inference measurement
            """
            inferer.accumulator = get_init_accumulator(infer_output)
            inferer.run(valid_loader)
            output_stat = process_accumulated_output(inferer.accumulator)
            update_logs(output_stat, engine.state.epoch, prefix='valid', color='red')

        @inferer.on(Events.ITERATION_COMPLETED)
        def accumulate_outputs(engine):
            batch_output = engine.state.output
            for key, item in batch_output.items():
                engine.accumulator[key].extend([item])
        ###
        # Setup is done; now run the training.
        trainer.run(train_loader, self.nr_epochs)
        return
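self.train_step and self.infer_step are referenced by the two Engine lambdas but defined elsewhere in the class. Minimal sketches consistent with how they are called and with the train_output/infer_output keys above; the (imgs, true) batch layout and the plain cross-entropy objective are assumptions (torch and torch.nn.functional as F are assumed imported):

    def train_step(self, net, batch, optimizer, device):
        # one optimization step; returns the keys consumed by RunningAverage
        net.train()
        imgs, true = batch
        imgs = imgs.to(device).float()
        true = true.to(device).long()
        logits = net(imgs)
        loss = F.cross_entropy(logits, true)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        acc = (logits.argmax(dim=1) == true).float().mean()
        return dict(loss=loss.item(), acc=acc.item())

    def infer_step(self, net, batch, device):
        # forward pass only; returns the keys gathered by accumulate_outputs
        net.eval()
        imgs, true = batch
        imgs = imgs.to(device).float()
        with torch.no_grad():
            prob = F.softmax(net(imgs), dim=1)
        return dict(prob=prob.cpu().numpy(), true=true.numpy())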
Example #4
train_loader = torch.utils.data.DataLoader(train_dataset,
                                           batch_size=batch_size,
                                           shuffle=True)

val_loader = torch.utils.data.DataLoader(val_dataset,
                                         batch_size=batch_size,
                                         shuffle=True)  # shuffling is unnecessary for evaluation, but harmless
test_loader = torch.utils.data.DataLoader(test_dataset,
                                          batch_size=batch_size,
                                          shuffle=True)

# model definition and training parameters
net = DenseNet()

criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(net.parameters(), lr=0.0001)
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer,
                                                       mode='min',
                                                       factor=0.05,
                                                       patience=5,
                                                       verbose=True,
                                                       threshold=0.001)

if train:
    print('Starting training...')
    best_loss = np.inf
    losses = np.zeros(epochs)
    accs = np.zeros(epochs)
    val_losses = np.zeros(epochs)
    val_accs = np.zeros(epochs)
    for epoch in range(epochs):  # loop over the dataset multiple times
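        # --- The original snippet is truncated at this point. The body below is
        # --- a minimal sketch of one plausible epoch, consistent with the arrays,
        # --- criterion, scheduler, and best_loss defined above; it is an
        # --- assumption, not the original code.
        running_loss, correct, total = 0.0, 0, 0
        net.train()
        for inputs, labels in train_loader:
            optimizer.zero_grad()
            outputs = net(inputs)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()
            running_loss += loss.item()
            correct += (outputs.argmax(dim=1) == labels).sum().item()
            total += labels.size(0)
        losses[epoch] = running_loss / len(train_loader)
        accs[epoch] = correct / total

        # validation pass drives ReduceLROnPlateau and best-loss tracking
        net.eval()
        val_loss, val_correct, val_total = 0.0, 0, 0
        with torch.no_grad():
            for inputs, labels in val_loader:
                outputs = net(inputs)
                val_loss += criterion(outputs, labels).item()
                val_correct += (outputs.argmax(dim=1) == labels).sum().item()
                val_total += labels.size(0)
        val_losses[epoch] = val_loss / len(val_loader)
        val_accs[epoch] = val_correct / val_total
        scheduler.step(val_losses[epoch])
        if val_losses[epoch] < best_loss:
            best_loss = val_losses[epoch]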