model = model.to(device)

        PAD = cfg.TRAIN.pad
        img_list = os.listdir(base_path)
        for f_img in img_list:
            print('Inference: ' + f_img, end=' ')
            # raw = np.asarray(Image.open(os.path.join(base_path, f_img)).convert('L'))
            raw = np.asarray(Image.open(os.path.join(base_path, f_img)))
            raw = raw.transpose(2, 0, 1)
            if cfg.TRAIN.track == 'complex':
                if raw.shape[0] == 9959 or raw.shape[0] == 9958:
                    raw_ = np.zeros((10240,10240), dtype=np.uint8)
                    raw_[141:141+9959, 141:141+9958] = raw
                    raw = raw_
                    del raw_
            crop_img = Crop_image(raw,crop_size=cfg.TRAIN.crop_size,overlap=cfg.TRAIN.overlap)
            start = time.time()
            for i in range(crop_img.num):
                for j in range(crop_img.num):
                    raw_crop = crop_img.gen(i, j)
                    #########
                    #inference
                    if crop_img.dim == 3:
                        raw_crop_ = raw_crop[np.newaxis, :, :, :]
                    else:
                        raw_crop_ = raw_crop[np.newaxis, np.newaxis, :, :]
                    inputs = torch.Tensor(raw_crop_).to(device)
                    inputs = F.pad(inputs, (PAD, PAD, PAD, PAD))
                    with torch.no_grad():
                        pred = model(inputs)
                    # pred = pred[0]
def loop(cfg, train_provider, valid_provider, model, criterion, optimizer, iters, writer):
    """Run the main train / log / validate / checkpoint loop.

    Trains ``model`` until ``cfg.TRAIN.total_iters``, periodically:
      * logging the running loss (every ``display_freq`` iterations),
      * saving a qualitative [input | prediction | target] preview PNG
        (every ``valid_freq`` iterations),
      * running sliding-window crop inference over the validation set and
        checkpointing the best common-F1 model (every ``save_freq``).

    Args:
        cfg: config namespace; reads ``cfg.TRAIN.*`` plus ``record_path``,
            ``cache_path`` and ``save_path``.
        train_provider: yields ``(input, target)`` tensors via ``next()``.
        valid_provider: exposes ``data.num`` and ``data.gen(k)`` returning
            ``(raw, label)`` numpy arrays.
        model: network; called as ``model(padded_input)``.
        criterion: loss function ``criterion(pred, target)``.
        optimizer: torch optimizer over ``model``'s parameters.
        iters: starting iteration count (supports resuming).
        writer: tensorboard-style summary writer.
    """
    PAD = cfg.TRAIN.pad
    # 'a' (append) so a resumed run keeps earlier log history.
    f_loss_txt = open(os.path.join(cfg.record_path, 'loss.txt'), 'a')
    f_valid_txt = open(os.path.join(cfg.record_path, 'valid.txt'), 'a')
    rcd_time = []        # wall time per display window, used for the ETA estimate
    sum_time = 0
    sum_loss = 0
    thresd = cfg.TRAIN.thresd  # binarization threshold for validation predictions
    most_f1 = 0          # best validation F1 seen so far
    most_iters = 0       # iteration at which most_f1 was reached
    device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')

    while iters <= cfg.TRAIN.total_iters:
        # ---- train one step ----
        iters += 1
        t1 = time.time()
        input, target = train_provider.next()

        # Decay the learning rate unless the schedule is flat (end == base).
        if cfg.TRAIN.end_lr == cfg.TRAIN.base_lr:
            current_lr = cfg.TRAIN.base_lr
        else:
            current_lr = calculate_lr(iters)
            for param_group in optimizer.param_groups:
                param_group['lr'] = current_lr

        optimizer.zero_grad()
        input = F.pad(input, (PAD, PAD, PAD, PAD))
        pred = model(input)
        # Negative pad crops the prediction back to the unpadded target size.
        pred = F.pad(pred, (-PAD, -PAD, -PAD, -PAD))

        loss = criterion(pred, target)
        loss.backward()
        if cfg.TRAIN.weight_decay is not None:
            # Decoupled weight decay applied manually: p <- p * (1 - wd * lr).
            # Replaces the deprecated Tensor.add(Number, Tensor) overload
            # (p.add(-wd*lr, p)), which raises on modern PyTorch; the
            # in-place mul_ is numerically identical.
            for group in optimizer.param_groups:
                for param in group['params']:
                    param.data.mul_(1 - cfg.TRAIN.weight_decay * group['lr'])
        optimizer.step()

        sum_loss += loss.item()
        sum_time += time.time() - t1

        # ---- log train ----
        if iters % cfg.TRAIN.display_freq == 0:
            rcd_time.append(sum_time)
            # NOTE: the reported loss is scaled by 10 ('wt: *10' in the message).
            logging.info('step %d, loss = %.4f (wt: *10, lr: %.8f, et: %.2f sec, rd: %.2f min)'
                            % (iters, sum_loss / cfg.TRAIN.display_freq * 10, current_lr, sum_time,
                            (cfg.TRAIN.total_iters - iters) / cfg.TRAIN.display_freq * np.mean(np.asarray(rcd_time)) / 60))
            writer.add_scalar('loss', sum_loss / cfg.TRAIN.display_freq * 10, iters)
            f_loss_txt.write('step = ' + str(iters) + ', loss = ' + str(sum_loss / cfg.TRAIN.display_freq * 10))
            f_loss_txt.write('\n')
            f_loss_txt.flush()
            sys.stdout.flush()
            sum_time = 0
            sum_loss = 0

        # ---- qualitative preview: [input | prediction | target] as 8-bit PNG ----
        if iters % cfg.TRAIN.valid_freq == 0:
            input = F.pad(input, (-PAD, -PAD, -PAD, -PAD))
            input0 = (np.squeeze(input[0].data.cpu().numpy()) * 255).astype(np.uint8)
            target = (np.squeeze(target[0].data.cpu().numpy()) * 255).astype(np.uint8)
            pred = np.squeeze(pred[0].data.cpu().numpy())
            # clamp to [0, 1] before the uint8 conversion
            pred = np.clip(pred, 0, 1)
            pred = (pred * 255).astype(np.uint8)
            input0 = input0[0]
            im_cat = np.concatenate([input0, pred, target], axis=1)
            Image.fromarray(im_cat).save(os.path.join(cfg.cache_path, '%06d.png' % iters))

        # ---- validate and save checkpoint ----
        if iters % cfg.TRAIN.save_freq == 0:
            if cfg.TRAIN.if_valid:
                # Reset per validation round so the averages reflect only this
                # checkpoint (previously the lists accumulated across rounds
                # and diluted every later average).
                valid_score = []      # "soft" F1 placeholder (always 0 here)
                valid_score_tmp = []  # common (thresholded) F1 per image
                model.eval()  # inference mode, matching the sibling loop
                for k in range(valid_provider.data.num):
                    raw, label = valid_provider.data.gen(k)
                    crop_img = Crop_image(raw,crop_size=cfg.TRAIN.crop_size,overlap=cfg.TRAIN.overlap)
                    for i in range(crop_img.num):
                        for j in range(crop_img.num):
                            raw_crop = crop_img.gen(i, j)
                            # add batch (and channel, for 2-D crops) axes
                            if crop_img.dim == 3:
                                raw_crop_ = raw_crop[np.newaxis, :, :, :]
                            else:
                                raw_crop_ = raw_crop[np.newaxis, np.newaxis, :, :]
                            inputs = torch.Tensor(raw_crop_).to(device)
                            inputs = F.pad(inputs, (PAD, PAD, PAD, PAD))
                            with torch.no_grad():
                                pred = model(inputs)
                            pred = F.pad(pred, (-PAD, -PAD, -PAD, -PAD))
                            pred = np.squeeze(pred.data.cpu().numpy())
                            crop_img.save(i, j, pred)
                    # stitch crops back together, then binarize at thresd
                    results = crop_img.result()
                    results[results<=thresd] = 0
                    results[results>thresd] = 1
                    # F1 on the inverted masks — assumes the class of interest
                    # is labeled 0 in `label`; verify against the data provider.
                    f1_common = f1_score(1 - label.flatten(), 1 - results.flatten())
                    f1 = 0
                    results_img = (results * 255).astype(np.uint8)
                    label_img = (label * 255).astype(np.uint8)
                    im_cat_valid = np.concatenate([results_img, label_img], axis=1)
                    # only the last validation image is saved as a preview
                    if k == valid_provider.data.num - 1:
                        Image.fromarray(im_cat_valid).save(os.path.join(cfg.cache_path, 'valid_%06d.png' % iters))
                    valid_score.append(f1)
                    valid_score_tmp.append(f1_common)
                model.train()  # restore training mode for subsequent steps
                avg_f1_soft = sum(valid_score) / len(valid_score)
                avg_f1_tmp = sum(valid_score_tmp) / len(valid_score_tmp)
                logging.info('step %d, f1_soft = %.6f' % (iters, avg_f1_soft))
                logging.info('step %d, f1_common = %.6f' % (iters, avg_f1_tmp))
                writer.add_scalar('valid', avg_f1_tmp, iters)
                f_valid_txt.write('step %d, f1 = %.6f' % (iters, avg_f1_tmp))
                f_valid_txt.write('\n')
                f_valid_txt.flush()
                sys.stdout.flush()
                # keep only the best-scoring checkpoint at a fixed path
                if avg_f1_tmp > most_f1:
                    most_f1 = avg_f1_tmp
                    most_iters = iters
                    states = {'current_iter': iters, 'valid_result': avg_f1_tmp,
                            'model_weights': model.state_dict()}
                    torch.save(states, os.path.join(cfg.save_path, 'model.ckpt'))
                    print('***************save model, when f1 = %.6f and iters = %d.***************' % (most_f1, most_iters))
            else:
                # No validation configured: snapshot periodically by iteration.
                states = {'current_iter': iters, 'valid_result': None,
                        'model_weights': model.state_dict()}
                torch.save(states, os.path.join(cfg.save_path, 'model-%d.ckpt' % iters))
                print('save model-%d' % iters)

    # training finished: release the log file handles
    f_loss_txt.close()
    f_valid_txt.close()
# Example #3
# 0
def loop(cfg, train_provider, valid_provider, model, criterion, optimizer, iters, writer):
    """Train / validate / checkpoint loop for a multi-output ("twonet") model
    with an additional VGG19 perceptual loss term.

    Visible differences from a plain training loop:
      * the model returns a list of predictions; ``multiloss`` yields three
        loss terms combined into ``loss_twonet``,
      * channel 1 of prediction[2] is replicated to 3 channels and compared
        against the replicated target through VGG19 feature maps
        (``vgg_loss``), weighted by ``cfg.TRAIN.vgg_weight``,
      * validation takes the argmax over the softmax of prediction[2].

    NOTE(review): this function may be truncated here — the sibling version
    has an ``else`` branch saving 'model-%d.ckpt' after the best-F1 save.
    """
    PAD = cfg.TRAIN.pad
    # 'w' truncates any logs from a previous run (unlike the sibling loop,
    # which appends).
    f_loss_txt = open(os.path.join(cfg.record_path, 'loss.txt'), 'w')
    f_valid_txt = open(os.path.join(cfg.record_path, 'valid.txt'), 'w')
    rcd_time = []        # wall time per display window, for the ETA estimate
    sum_time = 0
    sum_loss = 0
    # NOTE(review): these two lists are never reset, so each validation
    # round's average also includes scores from all previous rounds —
    # likely unintended.
    valid_score = []
    valid_score_tmp = []
    thresd = cfg.TRAIN.thresd  # binarization threshold for validation maps
    most_f1 = 0          # best validation F1 seen so far
    most_iters = 0       # iteration at which most_f1 was reached
    device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
    ######################################
    # loss_vgg = PerceptualLoss(mode=2).to(device)
    # VGG19 used only as a feature extractor for the perceptual loss.
    # NOTE(review): whether its weights are frozen is not visible here —
    # confirm in VGG19's definition.
    model_vgg = VGG19().to(device)
    cuda_count = torch.cuda.device_count()
    if cuda_count > 1:
        model_vgg = nn.DataParallel(model_vgg)
        print('VGG build on %d GPUs ... ' % cuda_count, flush=True)
    else:
        print('VGG build on a single GPU ... ', flush=True)
    # cuda_count = torch.cuda.device_count()
    # if cuda_count > 1:
    #     if cfg.TRAIN.batch_size % cuda_count == 0:
    #         loss_vgg = nn.DataParallel(loss_vgg)
    ######################################
    # mse_weight = cfg.TRAIN.mse_weight
    # weight of the perceptual term in the total loss
    vgg_weight = cfg.TRAIN.vgg_weight

    ### loss function
    # if cfg.MODEL.loss_func_logits:
    #     loss_function = F.binary_cross_entropy_with_logits
    # else:
    #     loss_function = F.binary_cross_entropy

    # final_loss = 0
    while iters <= cfg.TRAIN.total_iters:
        
        # train
        iters += 1
        t1 = time.time()
        input, target = train_provider.next()
        
        # decay learning rate (flat schedule when end_lr == base_lr)
        if cfg.TRAIN.end_lr == cfg.TRAIN.base_lr:
            current_lr = cfg.TRAIN.base_lr
        else:
            current_lr = calculate_lr(iters)
            for param_group in optimizer.param_groups:
                param_group['lr'] = current_lr
        
        optimizer.zero_grad()
        # pad input; predictions are cropped back below with negative padding
        input = F.pad(input, (PAD, PAD, PAD, PAD))
        model.train()
        pred = model(input)
        # if not cfg.MODEL.loss_func_logits:
        #     for i in range(len(pred)):
        #         pred[i] = torch.sigmoid(pred[i])
        # crop each of the model's outputs back to the unpadded size
        pred_depad = []
        for p in pred:
            pred_depad.append(F.pad(p, (-PAD, -PAD, -PAD, -PAD)))
        del pred
        # target = torch.unsqueeze(target, dim=1)
        # if iters == 1:
        #     writer.add_graph(model, (input,))

        ############################## Compute Loss ########################################
        # if cfg.MODEL.loss_balance_weight:
        #     cur_weight = edge_weight(target)
        #     writer.add_histogram('weight_edge', cur_weight.clone().cpu().data.numpy(), iters)
        # else:
        #     cur_weight = None

        loss1, loss2, loss3 = multiloss(pred_depad, target)
        # channel 1 of the final output — presumably the foreground class of a
        # 2-channel softmax head (see the argmax over dim=1 in validation).
        pred_depad_2 = pred_depad[2][:,1]
        pred_depad_2 = torch.unsqueeze(pred_depad_2, dim=1)
        # replicate single-channel maps to 3 channels so VGG19 accepts them
        pred_ = torch.cat([pred_depad_2, pred_depad_2, pred_depad_2], dim=1)
        target_ = torch.cat([target, target, target], dim=1)
        # loss_twonet = 0.25*loss1+loss2+loss3
        # weighted combination of the three multiloss terms
        loss_twonet = 0.75*(0.5*loss1+loss2)+loss3
        # loss_vgg2 = loss_vgg(pred_, target_)
        # perceptual loss: compare VGG feature maps of prediction and target
        out1 = model_vgg(pred_)
        out2 = model_vgg(target_)
        loss_vgg2 = vgg_loss(out1, out2)
        loss = loss_twonet + vgg_weight * loss_vgg2
        
        loss.backward()
        if cfg.TRAIN.weight_decay is not None:
            # manual decoupled weight decay: p <- p - wd * lr * p
            # NOTE(review): Tensor.add(Number, Tensor) is deprecated and
            # errors on modern PyTorch.
            for group in optimizer.param_groups:
                for param in group['params']:
                    param.data = param.data.add(-cfg.TRAIN.weight_decay * group['lr'], param.data)
        optimizer.step()
        
        sum_loss += loss.item()
        sum_time += time.time() - t1
        
        # log train (reported loss is scaled by 10 — 'wt: *10')
        if iters % cfg.TRAIN.display_freq == 0:
            rcd_time.append(sum_time)
            logging.info('step %d, loss = %.4f (wt: *10, lr: %.8f, et: %.2f sec, rd: %.2f min)'
                            % (iters, sum_loss / cfg.TRAIN.display_freq * 10, current_lr, sum_time,
                            (cfg.TRAIN.total_iters - iters) / cfg.TRAIN.display_freq * np.mean(np.asarray(rcd_time)) / 60))
            logging.info('step %d, loss_twonet = %.4f, loss_vgg = %.4f' % (iters, loss_twonet, vgg_weight * loss_vgg2))
            writer.add_scalar('loss/loss1', loss1, iters)
            writer.add_scalar('loss/loss2', loss2, iters)
            writer.add_scalar('loss/loss3', loss3, iters)
            writer.add_scalar('loss/loss_vgg', loss_vgg2, iters)
            writer.add_scalar('final_loss', sum_loss / cfg.TRAIN.display_freq * 10, iters)
            f_loss_txt.write('step = ' + str(iters) + ', loss = ' + str(sum_loss / cfg.TRAIN.display_freq * 10))
            f_loss_txt.write('\n')
            f_loss_txt.flush()
            sys.stdout.flush()
            sum_time = 0
            sum_loss = 0
        
        # valid: dump a 2x3 preview grid of input/targets/predictions
        if iters % cfg.TRAIN.valid_freq == 0:
            input = F.pad(input, (-PAD, -PAD, -PAD, -PAD))
            input0 = (np.squeeze(input[0].data.cpu().numpy()) * 255).astype(np.uint8)
            target = (np.squeeze(target[0].data.cpu().numpy()) * 255).astype(np.uint8)
            pred_show = []
            for p in pred_depad:
                temp = np.squeeze(p[0].data.cpu().numpy())
                # clamp to [0, 1] before uint8 conversion
                temp[temp>1] = 1; temp[temp<0] = 0
                temp = (temp * 255).astype(np.uint8)
                pred_show.append(temp)
            # input0 = input0[0]
            # filler tile so the 2x3 grid is complete
            white = np.ones_like(pred_show[2][1], dtype=np.uint8)
            im1 = np.concatenate([input0, pred_show[0][1], pred_show[1][1]], axis=1)
            im2 = np.concatenate([target, pred_show[2][1], white], axis=1)
            im_cat = np.concatenate([im1, im2], axis=0)
            Image.fromarray(im_cat).save(os.path.join(cfg.cache_path, '%06d.png' % iters))
        
        # save: sliding-window validation, then checkpoint on best common F1
        if iters % cfg.TRAIN.save_freq == 0:
            model.eval()
            for k in range(valid_provider.data.num):
                raw, label = valid_provider.data.gen(k)
                crop_img = Crop_image(raw,crop_size=cfg.TRAIN.crop_size,overlap=cfg.TRAIN.overlap)
                for i in range(crop_img.num):
                    for j in range(crop_img.num):
                        raw_crop = crop_img.gen(i, j)
                        #########
                        #inference
                        # add batch (and channel, for 2-D crops) axes
                        if crop_img.dim == 3:
                            raw_crop_ = raw_crop[np.newaxis, :, :, :]
                        else:
                            raw_crop_ = raw_crop[np.newaxis, np.newaxis, :, :]
                        inputs = torch.Tensor(raw_crop_).to(device)
                        inputs = F.pad(inputs, (PAD, PAD, PAD, PAD))
                        
                        with torch.no_grad():
                            pred = model(inputs)
                        # use only the final output; hard class map via argmax
                        pred = pred[2]
                        pred = F.pad(pred, (-PAD, -PAD, -PAD, -PAD))
                        pred = F.softmax(pred, dim=1)
                        pred = torch.argmax(pred, dim=1).squeeze(0)
                        pred = pred.data.cpu().numpy()
                        # pred = np.squeeze(pred)
                        # pred = pred[1]
                        #########
                        crop_img.save(i, j, pred)
                # stitch the crops, then binarize at thresd
                results = crop_img.result()
                results[results<=thresd] = 0
                results[results>thresd] = 1
                temp_label = label.flatten()
                temp_result = results.flatten()
                # F1 on inverted masks — assumes the class of interest is
                # labeled 0; verify against the data provider.
                f1_common = f1_score(1 - temp_label, 1 - temp_result)
                f1_soft = 0
                results_img = (results * 255).astype(np.uint8)
                label_img = (label * 255).astype(np.uint8)
                im_cat_valid = np.concatenate([results_img, label_img], axis=1)
                # only the last validation image is saved as a preview
                if k == valid_provider.data.num - 1:
                    Image.fromarray(im_cat_valid).save(os.path.join(cfg.cache_path, 'valid_%06d.png' % iters))
                valid_score.append(f1_soft)
                valid_score_tmp.append(f1_common)
            avg_f1 = sum(valid_score) / len(valid_score)
            avg_f1_tmp = sum(valid_score_tmp) / len(valid_score_tmp)
            logging.info('step %d, f1_soft = %.6f' % (iters, avg_f1))
            logging.info('step %d, f1_common = %.6f' % (iters, avg_f1_tmp))
            writer.add_scalar('valid', avg_f1_tmp, iters)
            f_valid_txt.write('step %d, f1 = %.6f' % (iters, avg_f1_tmp))
            f_valid_txt.write('\n')
            f_valid_txt.flush()
            sys.stdout.flush()
            # keep only the best-scoring checkpoint at a fixed path
            if avg_f1_tmp > most_f1:
                most_f1 = avg_f1_tmp
                most_iters = iters
                states = {'current_iter': iters, 'valid_result': avg_f1_tmp,
                        'model_weights': model.state_dict()}
                torch.save(states, os.path.join(cfg.save_path, 'model.ckpt'))
                print('***************save modol, when f1 = %.6f and iters = %d.***************' % (most_f1, most_iters))