Example #1

import sys
import time

import numpy as np
import torch
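# AverageMeter, visualize, visualize_aff, and DiscriminativeLoss are provided by
# this repository's own utility and loss modules. As a reference, a minimal sketch
# of the running-average helper these loops assume (the real implementation may
# differ) could look like this:
class AverageMeter:
    """Tracks a running sum and average of a scalar value."""

    def __init__(self):
        self.reset()

    def reset(self):
        self.val, self.sum, self.count, self.avg = 0.0, 0.0, 0, 0.0

    def update(self, val, n=1):
        self.val = val
        self.sum += val * n
        self.count += n
        self.avg = self.sum / self.count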
def train(args,
          train_loader,
          val_loader,
          model,
          model_cpu,
          device,
          criterion,
          optimizer,
          scheduler,
          logger,
          writer,
          regularization=None,
          model_lstm=None):
    record = AverageMeter()
    val_record = AverageMeter()

    model.train()

    if val_loader is not None:
        val_loader_itr = iter(val_loader)

    start = time.time()
    for iteration, data in enumerate(train_loader):
        print('time taken for itr: ', time.time() - start)
        start = time.time()
        sys.stdout.flush()

        if args.in_channel == 2:
            # Two-channel input: the raw volume concatenated with the input label channel.
            (_, volume, input_label, label, class_weight, _) = data
            volume, input_label, label = volume.to(device), input_label.to(
                device), label.to(device)
            output = model(torch.cat((volume, input_label), 1))
        else:
            (_, volume, label, class_weight, _) = data

            volume, label = volume.to(device), label.to(device)
            output = model(volume)
            if model_lstm is not None:
                # Optional LSTM head applied on top of the backbone output; the
                # pre-LSTM output is kept detached.
                output_pre_lstm = output.clone().detach()
                output = model_lstm(output.clone())

        class_weight = class_weight.to(device)
        if regularization is not None:
            loss = criterion(output, label,
                             class_weight) + regularization(output)
        else:
            loss = criterion(output, label, class_weight)

        # Store a detached scalar so the autograd graph is not kept alive.
        record.update(loss.item(), args.batch_size)

        # compute gradient and do Adam step
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        logger.write("[Volume %d] train_loss=%0.4f lr=%.5f\n" %
                     (iteration, loss.item(), optimizer.param_groups[0]['lr']))
        print('[Iteration %d] train_loss=%0.4f lr=%.6f' %
              (iteration, loss.item(), optimizer.param_groups[0]['lr']))
        writer.add_scalars('Loss', {'Train': loss.item()}, iteration)

        if iteration % 50 == 0 and iteration >= 1:

            # Sync the CPU copy of the model with the latest GPU weights.
            params_gpu = model.named_parameters()
            params_cpu = model_cpu.named_parameters()
            dict_p = dict(params_gpu)
            dict_params_cpu = dict(params_cpu)
            for name, param in dict_p.items():
                dict_params_cpu[name].data.copy_(param.data)

            if args.task == 0:
                visualize_aff(volume,
                              label,
                              output,
                              iteration,
                              writer,
                              mode='Train')
            elif args.task == 1 or args.task == 3:
                if args.in_channel == 2:
                    visualize(volume,
                              label,
                              output,
                              iteration,
                              writer,
                              input_label=input_label)
                else:
                    visualize(volume, label, output, iteration, writer)

            scheduler.step(record.avg)
            record.reset()

            if val_loader is not None:
                model.eval()
                val_record.reset()

                # Run several validation batches for better coverage of the validation set.
                for _ in range(5):
                    try:
                        (_, volume, label, class_weight,
                         _) = next(val_loader_itr)
                    except StopIteration:
                        val_loader_itr = iter(val_loader)
                        (_, volume, label, class_weight,
                         _) = next(val_loader_itr)

                    with torch.no_grad():
                        volume, label = volume.to(device), label.to(device)
                        class_weight = class_weight.to(device)
                        output = model(volume)

                        if regularization is not None:
                            val_loss = criterion(
                                output, label,
                                class_weight) + regularization(output)
                        else:
                            val_loss = criterion(output, label, class_weight)

                        val_record.update(val_loss.item(), args.batch_size)

                writer.add_scalars('Loss', {'Val': val_record.avg}, iteration)
                print('[Iteration %d] val_loss=%0.4f lr=%.6f' %
                      (iteration, val_record.avg,
                       optimizer.param_groups[0]['lr']))

                if args.task == 0:
                    visualize_aff(volume,
                                  label,
                                  output,
                                  iteration,
                                  writer,
                                  mode='Validation')
                elif args.task == 1 or args.task == 3:
                    visualize(volume, label, output, iteration, writer)

                model.train()

            #print('weight factor: ', weight_factor) # debug
            # debug
            # if iteration < 50:
            #     fl = h5py.File('debug_%d_h5' % (iteration), 'w')
            #     output = label[0].cpu().detach().numpy().astype(np.uint8)
            #     print(output.shape)
            #     fl.create_dataset('main', data=output)
            #     fl.close()

        # Save model
        if iteration % args.iteration_save == 0 or iteration >= args.iteration_total:
            torch.save(model.state_dict(),
                       args.output + (args.exp_name + '_%d.pth' % iteration))
            if model_lstm is not None:
                torch.save(
                    model_lstm.state_dict(), args.output +
                    (args.exp_name + '_headLSTM_%d.pth' % iteration))

        # Terminate
        if iteration >= args.iteration_total:
            break
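
# The calls to scheduler.step(record.avg) above pass a metric, which implies a
# metric-driven scheduler. A minimal sketch of the assumed optimizer/scheduler
# setup (the learning rate and scheduler settings here are placeholders, not the
# project's actual configuration):
def build_optimizer_sketch(model, lr=1e-3):
    optimizer = torch.optim.Adam(model.parameters(), lr=lr)
    scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
        optimizer, mode='min', factor=0.5, patience=10)
    return optimizer, scheduler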
def train(args,
          train_loader,
          model,
          device,
          criterion,
          optimizer,
          scheduler,
          logger,
          writer,
          regularization=None):
    record = AverageMeter()
    model.train()

    # for iteration, (_, volume, label, class_weight, _) in enumerate(train_loader):
    for iteration, batch in enumerate(train_loader):

        if args.task == 22:
            _, volume, seg_mask, class_weight, _, label, out_skeleton = batch
        else:
            _, volume, label, class_weight, _ = batch
        volume, label = volume.to(device), label.to(device)
        #        seg_mask = seg_mask.to(device)
        class_weight = class_weight.to(device)
        output = model(volume)
        print('output_shape:', output.shape)
        print('label_shape:', label.shape)
        # Debug: count unique ids in the first three samples of output and label.
        # Tensors must be detached and moved to CPU before the NumPy cast, and
        # the dtype is np.uint8 (bare `uint8` is undefined).
        output_np = output.detach().cpu().numpy().astype(np.uint8)
        label_np = label.cpu().numpy().astype(np.uint8)
        print('unique_output_ids_0:', len(np.unique(output_np[0])))
        print('unique_output_ids_1:', len(np.unique(output_np[1])))
        print('unique_output_ids_2:', len(np.unique(output_np[2])))

        print('unique_label_ids_0:', len(np.unique(label_np[0])))
        print('unique_label_ids_1:', len(np.unique(label_np[1])))
        print('unique_label_ids_2:', len(np.unique(label_np[2])))

        if regularization is not None:
            loss = criterion(output, label,
                             class_weight) + regularization(output)
        else:
            loss = criterion(output, label, class_weight)
        record.update(loss.item(), args.batch_size)

        # compute gradient and do Adam step
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        logger.write("[Volume %d] train_loss=%0.4f lr=%.5f\n" %
                     (iteration, loss.item(), optimizer.param_groups[0]['lr']))

        if iteration % 10 == 0 and iteration >= 1:
            writer.add_scalar('Loss', record.avg, iteration)
            print('[Iteration %d] train_loss=%0.4f lr=%.6f' %
                  (iteration, record.avg, optimizer.param_groups[0]['lr']))
            scheduler.step(record.avg)
            record.reset()
            if args.task == 0:
                visualize_aff(volume, label, output, iteration, writer)
            elif args.task == 1 or args.task == 2 or args.task == 22 or args.task == 11:
                visualize(volume, label, output, iteration, writer)
            #print('weight factor: ', weight_factor) # debug
            # debug
            # if iteration < 50:
            #     fl = h5py.File('debug_%d_h5' % (iteration), 'w')
            #     output = label[0].cpu().detach().numpy().astype(np.uint8)
            #     print(output.shape)
            #     fl.create_dataset('main', data=output)
            #     fl.close()

        # Save model
        if iteration % args.iteration_save == 0 or iteration >= args.iteration_total:
            torch.save(model.state_dict(),
                       args.output + ('/volume_%d.pth' % (iteration)))

        # Terminate
        if iteration >= args.iteration_total:
            break
def train_2d(args, train_loader, model, device, criterion,
          optimizer, scheduler, logger, writer, regularization=None):
    record = AverageMeter()
    model.train()

    # for iteration, (_, volume, label, class_weight, _) in enumerate(train_loader):
    for iteration, batch in enumerate(train_loader):
        iteration = iteration + args.iteration_begin
        # print('begin:',iteration)
        if args.task == 22:
            # _, volume, seg_mask, class_weight, _, label, out_skeleton, out_valid = batch
            if args.valid_mask is None:
                _, volume, seg_mask, class_weight, _, label, out_skeleton = batch
            else:
                _, volume, seg_mask, class_weight, _, label, out_skeleton, out_valid = batch
                out_valid = out_valid.to(device)
        else:
            _, volume, label, class_weight, _ = batch
        volume, label = volume.to(device), label.to(device)
        #print(volume.shape,label.shape)
        # Drop the singleton z-dimension so the tensors are 2-D: (N, C, H, W).
        volume = volume.squeeze(2)
        label = label.squeeze(2)
        #        seg_mask = seg_mask.to(device)
        class_weight = class_weight.to(device)
        class_weight = class_weight.squeeze(2)
        if args.ebd == 1:
            p = 5
            # Flip p% of the ground-truth pixels to create a noisy auxiliary label
            # (the (320, 320) patch size is hard-coded here).
            noisy_label = (label.cpu().numpy() +
                           np.random.binomial(1, p / 100.0, (320, 320))) % 2
            noisy_label = torch.Tensor(noisy_label).to(device)
            # print('getting:')
            output = model(volume, noisy_label)
            # print('got:')
        else:
            output = model(volume)

        if args.task == 22 and args.valid_mask is not None:
            class_weight = class_weight * out_valid

        if regularization is not None:
            loss = criterion(output, label, class_weight) + regularization(output)
        else:
            loss = criterion(output, label, class_weight)
            # print('loss:')
        record.update(loss.item(), args.batch_size)

        # compute gradient and do Adam step
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        logger.write("[Volume %d] train_loss=%0.4f lr=%.5f\n" %
                     (iteration, loss.item(), optimizer.param_groups[0]['lr']))

        if iteration % 10 == 0 and iteration >= 1:
            writer.add_scalar('Loss', record.avg, iteration)
            print('[Iteration %d] train_loss=%0.4f lr=%.6f' %
                  (iteration, record.avg, optimizer.param_groups[0]['lr']))
            scheduler.step(record.avg)
            record.reset()
            if args.task == 0:
                visualize_aff(volume, label, output, iteration, writer)
            elif args.task == 1 or args.task == 2 or args.task == 22:
                visualize(volume, label, output, iteration, writer, aux=args.aux)
            # print('weight factor: ', weight_factor) # debug
            # debug
            # if iteration < 50:
            #     fl = h5py.File('debug_%d_h5' % (iteration), 'w')
            #     output = label[0].cpu().detach().numpy().astype(np.uint8)
            #     print(output.shape)
            #     fl.create_dataset('main', data=output)
            #     fl.close()

        # Save model
        if iteration % args.iteration_save == 0 or iteration >= args.iteration_total:
            torch.save(model.state_dict(), args.output + ('/volume_%d.pth' % (iteration)))

        # Terminate
        if iteration >= args.iteration_total:
            break
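
# A small sketch (an assumption, not part of the repository) of the noisy-label
# augmentation used in train_2d, written against the label's own shape instead of
# the hard-coded (320, 320) patch size.
def make_noisy_label(label, p=5.0):
    """Flip roughly p% of the binary label pixels."""
    flips = np.random.binomial(1, p / 100.0, size=tuple(label.shape))
    noisy = (label.cpu().numpy() + flips) % 2
    return torch.as_tensor(noisy, dtype=label.dtype, device=label.device)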
def train_dsc(args,
              train_loader,
              model,
              device,
              criterion,
              optimizer,
              scheduler,
              logger,
              writer,
              regularization=None):
    record = AverageMeter()
    model.train()
    print('train_dsc')
    # for iteration, (_, volume, label, class_weight, _) in enumerate(train_loader):
    for iteration, batch in enumerate(train_loader):
        iteration = iteration + args.iteration_begin
        #print('begin:',iteration)
        if args.task == 22:
            # _, volume, seg_mask, class_weight, _, label, out_skeleton, out_valid = batch
            if args.valid_mask is None:
                _, volume, seg_mask, class_weight, _, label, out_skeleton = batch
            else:
                _, volume, seg_mask, class_weight, _, label, out_skeleton, out_valid = batch
                out_valid = out_valid.to(device)
        else:
            _, volume, label, class_weight, _ = batch
        volume, label = volume.to(device), label.to(device)
        #        seg_mask = seg_mask.to(device)
        class_weight = class_weight.to(device)

        seg_out, ins_out = model(volume)

        # Instance-embedding (discriminative) loss on the embedding head, plus the
        # usual weighted segmentation loss on the segmentation head.
        ins_criterion = DiscriminativeLoss(device)

        # Build a two-"instance" target (background, foreground) and flatten the
        # spatial dimensions for the embedding loss.
        label_0 = 1 - label
        ins_label = torch.cat((label_0, label), 1)
        bs, ch, d, h, w = ins_out.shape
        ins_label = ins_label.contiguous().view(bs, 2, d * h * w)
        ins_out = ins_out.contiguous().view(bs, ch, d * h * w)

        ins_loss = ins_criterion(ins_out, ins_label, [2] * bs)

        seg_loss = criterion(seg_out, label, class_weight)

        loss = ins_loss + seg_loss
        record.update(loss.item(), args.batch_size)

        # Visualize the segmentation head's output below.
        output = seg_out
        # compute gradient and do Adam step
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        logger.write("[Volume %d] train_loss=%0.4f lr=%.5f\n" %
                     (iteration, loss.item(), optimizer.param_groups[0]['lr']))

        if iteration % 10 == 0 and iteration >= 1:
            writer.add_scalar('Loss', record.avg, iteration)
            print('[Iteration %d] train_loss=%0.4f lr=%.6f' %
                  (iteration, record.avg, optimizer.param_groups[0]['lr']))
            scheduler.step(record.avg)
            record.reset()
            if args.task == 0:
                visualize_aff(volume, label, output, iteration, writer)
            elif args.task == 1 or args.task == 2 or args.task == 22:
                visualize(volume,
                          label,
                          output,
                          iteration,
                          writer,
                          aux=args.aux)
            # print('weight factor: ', weight_factor) # debug
            # debug
            # if iteration < 50:
            #     fl = h5py.File('debug_%d_h5' % (iteration), 'w')
            #     output = label[0].cpu().detach().numpy().astype(np.uint8)
            #     print(output.shape)
            #     fl.create_dataset('main', data=output)
            #     fl.close()

        # Save model
        if iteration % args.iteration_save == 0 or iteration >= args.iteration_total:
            torch.save(model.state_dict(),
                       args.output + ('/volume_%d.pth' % (iteration)))

        # Terminate
        if iteration >= args.iteration_total:
            break
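
# A minimal sketch of a push/pull ("discriminative") embedding loss in the spirit
# of the DiscriminativeLoss used by train_dsc above (an assumption, not the
# repository's implementation): pull pixel embeddings toward their instance mean,
# push the instance means apart.
def discriminative_loss_sketch(embedding, one_hot, delta_v=0.5, delta_d=1.5):
    # embedding: (bs, ch, n) pixel embeddings; one_hot: (bs, k, n) binary instance masks.
    bs = embedding.shape[0]
    k = one_hot.shape[1]
    loss = embedding.new_zeros(())
    for b in range(bs):
        means, pull = [], embedding.new_zeros(())
        for i in range(k):
            mask = one_hot[b, i]                                   # (n,)
            count = mask.sum().clamp(min=1.0)
            mu = (embedding[b] * mask).sum(dim=1) / count          # (ch,)
            means.append(mu)
            # Pull term: penalize pixels farther than delta_v from their mean.
            dist = torch.norm(embedding[b] - mu[:, None], dim=0) * mask
            pull = pull + (torch.clamp(dist - delta_v, min=0.0) ** 2).sum() / count
        means = torch.stack(means)                                 # (k, ch)
        # Push term: penalize pairs of means closer than 2 * delta_d.
        push = torch.clamp(2 * delta_d - torch.cdist(means, means), min=0.0) ** 2
        push = (push.sum() - push.diagonal().sum()) / max(k * (k - 1), 1)
        loss = loss + pull / k + push
    return loss / bs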