args = parser.parse_args()
if args.modality == 'rgb' and args.num_samples != 0:
    print("number of samples is forced to be 0 when input modality is rgb")
    args.num_samples = 0
if args.modality == 'rgb' and args.max_depth != 0.0:
    print("max depth is forced to be 0.0 when input modality is rgb/rgbd")
    args.max_depth = 0.0
print(args)

fieldnames = [
    'epoch', 'rmse', 'photo', 'mae', 'irmse', 'imae', 'mse', 'absrel',
    'lg10', 'silog', 'squared_rel', 'delta1', 'delta2', 'delta3',
    'data_time', 'gpu_time'
]
best_result = Result()
best_result.set_to_worst()


def main():
    global args, best_result, output_directory, train_csv, test_csv

    sparsifier = None
    max_depth = args.max_depth if args.max_depth >= 0.0 else np.inf
    if args.sparsifier == UniformSampling.name:
        sparsifier = UniformSampling(num_samples=args.num_samples,
                                     max_depth=max_depth)
    elif args.sparsifier == SimulatedStereo.name:
        sparsifier = SimulatedStereo(num_samples=args.num_samples,
                                     max_depth=max_depth)
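    # UniformSampling draws num_samples sparse depth points at random, while
    # SimulatedStereo mimics the sparse pattern a stereo matcher would produce;
    # both ignore depths beyond max_depth (np.inf disables the cap).
    # (Behavior summarized from the sparse-to-dense sparsifier classes; the
    # implementations are not shown in this snippet.)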

    # create the results folder if it does not already exist
class logger:
    def __init__(self, args, prepare=True):
        self.args = args

        self.args.save_pred = True

        output_directory = get_folder_name(args)
        self.output_directory = output_directory
        self.best_result = Result()
        self.best_result.set_to_worst()

        if not prepare:
            return
        if not os.path.exists(output_directory):
            os.makedirs(output_directory)
        self.train_csv = os.path.join(output_directory, 'train.csv')
        self.val_csv = os.path.join(output_directory, 'val.csv')
        self.best_txt = os.path.join(output_directory, 'best.txt')

        # backup the source code
        if args.resume == '':
            print("=> creating source code backup ...")
            backup_directory = os.path.join(output_directory, "code_backup")
            self.backup_directory = backup_directory
            print("=> stop source code backup ...")
            # backup_source_code(backup_directory)

            # create new csv files with only header
            with open(self.train_csv, 'w') as csvfile:
                writer = csv.DictWriter(csvfile, fieldnames=fieldnames)
                writer.writeheader()
            with open(self.val_csv, 'w') as csvfile:
                writer = csv.DictWriter(csvfile, fieldnames=fieldnames)
                writer.writeheader()
            print("=> finished creating source code backup.")

    def conditional_print(self, split, i, epoch, lr, n_set, blk_avg_meter,
                          avg_meter):
        if (i + 1) % self.args.print_freq == 0:
            avg = avg_meter.average()
            blk_avg = blk_avg_meter.average()
            print('=> output: {}'.format(self.output_directory))
            print(
                '{split} Epoch: {0} [{1}/{2}]\tlr={lr} '
                't_Data={blk_avg.data_time:.3f}({average.data_time:.3f}) '
                't_GPU={blk_avg.gpu_time:.3f}({average.gpu_time:.3f})\n\t'
                'RMSE={blk_avg.rmse:.2f}({average.rmse:.2f}) '
                'MAE={blk_avg.mae:.2f}({average.mae:.2f}) '
                'iRMSE={blk_avg.irmse:.2f}({average.irmse:.2f}) '
                'iMAE={blk_avg.imae:.2f}({average.imae:.2f})\n\t'
                'silog={blk_avg.silog:.2f}({average.silog:.2f}) '
                'squared_rel={blk_avg.squared_rel:.2f}({average.squared_rel:.2f}) '
                'Delta1={blk_avg.delta1:.3f}({average.delta1:.3f}) '
                'REL={blk_avg.absrel:.3f}({average.absrel:.3f})\n\t'
                'Lg10={blk_avg.lg10:.3f}({average.lg10:.3f}) '
                'Photometric={blk_avg.photometric:.3f}({average.photometric:.3f}) '
                .format(epoch,
                        i + 1,
                        n_set,
                        lr=lr,
                        blk_avg=blk_avg,
                        average=avg,
                        split=split.capitalize()))
            blk_avg_meter.reset()

    def conditional_save_info(self, split, average_meter, epoch):
        avg = average_meter.average()
        if split == "train":
            csvfile_name = self.train_csv
        elif split == "val":
            csvfile_name = self.val_csv
        elif split == "eval":
            eval_filename = os.path.join(self.output_directory, 'eval.txt')
            self.save_single_txt(eval_filename, avg, epoch)
            return avg
        elif "test" in split:
            return avg
        else:
            raise ValueError("wrong split provided to logger")
        with open(csvfile_name, 'a') as csvfile:
            writer = csv.DictWriter(csvfile, fieldnames=fieldnames)
            writer.writerow({
                'epoch': epoch,
                'rmse': avg.rmse,
                'photo': avg.photometric,
                'mae': avg.mae,
                'irmse': avg.irmse,
                'imae': avg.imae,
                'mse': avg.mse,
                'silog': avg.silog,
                'squared_rel': avg.squared_rel,
                'absrel': avg.absrel,
                'lg10': avg.lg10,
                'delta1': avg.delta1,
                'delta2': avg.delta2,
                'delta3': avg.delta3,
                'gpu_time': avg.gpu_time,
                'data_time': avg.data_time
            })
        return avg

    def save_single_txt(self, filename, result, epoch):
        with open(filename, 'w') as txtfile:
            txtfile.write(
                ("rank_metric={}\n" + "epoch={}\n" + "rmse={:.3f}\n" +
                 "mae={:.3f}\n" + "silog={:.3f}\n" + "squared_rel={:.3f}\n" +
                 "irmse={:.3f}\n" + "imae={:.3f}\n" + "mse={:.3f}\n" +
                 "absrel={:.3f}\n" + "lg10={:.3f}\n" + "delta1={:.3f}\n" +
                 "t_gpu={:.4f}").format(self.args.rank_metric, epoch,
                                        result.rmse, result.mae, result.silog,
                                        result.squared_rel, result.irmse,
                                        result.imae, result.mse, result.absrel,
                                        result.lg10, result.delta1,
                                        result.gpu_time))

    def save_best_txt(self, result, epoch):
        self.save_single_txt(self.best_txt, result, epoch)

    def _get_img_comparison_name(self, mode, epoch, is_best=False):
        if mode == 'eval':
            return self.output_directory + '/comparison_eval.png'
        if mode == 'val':
            if is_best:
                return self.output_directory + '/comparison_best.png'
            else:
                return self.output_directory + '/comparison_' + str(
                    epoch) + '.png'

    def conditional_save_img_comparison(self, mode, i, ele, pred, epoch):
        # save 8 images for visualization
        if mode == 'val' or mode == 'eval':
            skip = 100
            if i == 0:
                self.img_merge = vis_utils.merge_into_row(ele, pred)
            elif i % skip == 0 and i < 8 * skip:
                row = vis_utils.merge_into_row(ele, pred)
                self.img_merge = vis_utils.add_row(self.img_merge, row)
            elif i == 8 * skip:
                filename = self._get_img_comparison_name(mode, epoch)
                vis_utils.save_image(self.img_merge, filename)

    def save_img_comparison_as_best(self, mode, epoch):
        if mode == 'val':
            filename = self._get_img_comparison_name(mode, epoch, is_best=True)
            vis_utils.save_image(self.img_merge, filename)

    def get_ranking_error(self, result):
        return getattr(result, self.args.rank_metric)

    def rank_conditional_save_best(self, mode, result, epoch):
        error = self.get_ranking_error(result)
        best_error = self.get_ranking_error(self.best_result)
        is_best = error < best_error
        if is_best and mode == "val":
            self.old_best_result = self.best_result
            self.best_result = result
            self.save_best_txt(result, epoch)
        return is_best

    def conditional_save_pred(self, mode, file_name, pred, epoch):
        if ("test" in mode or mode == "eval") and self.args.save_pred:

            # save images for visualization/ testing
            image_folder = os.path.join(self.output_directory,
                                        mode + "_output")
            if not os.path.exists(image_folder):
                os.makedirs(image_folder)
            img = torch.squeeze(pred.data.cpu()).numpy()
            file_path = os.path.join(image_folder, file_name)
            vis_utils.save_depth_as_uint16png(img, file_path)

    def conditional_summarize(self, mode, avg, is_best):
        print("\n*\nSummary of ", mode, "round")
        print(''
              'RMSE={average.rmse:.3f}\n'
              'MAE={average.mae:.3f}\n'
              'Photo={average.photometric:.3f}\n'
              'iRMSE={average.irmse:.3f}\n'
              'iMAE={average.imae:.3f}\n'
              'squared_rel={average.squared_rel}\n'
              'silog={average.silog}\n'
              'Delta1={average.delta1:.3f}\n'
              'REL={average.absrel:.3f}\n'
              'Lg10={average.lg10:.3f}\n'
              't_GPU={time:.3f}'.format(average=avg, time=avg.gpu_time))
        if is_best and mode == "val":
            print("New best model by %s (was %.3f)" %
                  (self.args.rank_metric,
                   self.get_ranking_error(self.old_best_result)))
        elif mode == "val":
            print("(best %s is %.3f)" %
                  (self.args.rank_metric,
                   self.get_ranking_error(self.best_result)))
        print("*\n")
Example #3
class logger:
    def __init__(self, args, prepare=True):
        self.args = args
        output_directory = get_folder_name(args)
        self.output_directory = output_directory
        self.best_result = Result()
        self.best_result.set_to_worst()

        self.best_result_intensity = Result_intensity()
        self.best_result_intensity.set_to_worst()

        #visual
        # self.viz = Visdom(server='http://10.5.40.31', port=11207)
        # assert self.viz.check_connection()
        from datetime import datetime, timedelta, timezone
        utc_dt = datetime.utcnow().replace(tzinfo=timezone.utc)
        cn_dt = utc_dt.astimezone(timezone(timedelta(hours=8)))  # UTC+8 time
        TIMESTAMP = "{0:%Y-%m-%dT%H-%M-%S/}".format(cn_dt)
        name_prefix = "log_lr{}bt{}decay{}wi{}wpure{}lradj{}/".format(
            args.lr, args.batch_size, args.weight_decay, args.wi, args.wpure,
            args.lradj)
        #name_prefix = "log_lr{}bt{}decay{}_ws_{}_alphabeta_{}_{}/".format(args.lr ,args.batch_size, args.weight_decay, args.ws, args.alpha, args.beta)
        print(name_prefix + TIMESTAMP)
        self.writer = SummaryWriter(name_prefix + TIMESTAMP)

        if not prepare:
            return
        if not os.path.exists(output_directory):
            os.makedirs(output_directory)
        self.train_csv = os.path.join(output_directory, 'train.csv')
        self.val_csv = os.path.join(output_directory, 'val.csv')
        self.best_txt = os.path.join(output_directory, 'best.txt')
        self.train_csv_intensity = os.path.join(output_directory,
                                                'train_intensity.csv')
        self.val_csv_intensity = os.path.join(output_directory,
                                              'val_intensity.csv')

        # backup the source code
        if args.resume == '':
            print("=> creating source code backup ...")
            backup_directory = os.path.join(output_directory, "code_backup")
            self.backup_directory = backup_directory
            backup_source_code(backup_directory)
            # create new csv files with only header
            with open(self.train_csv, 'w') as csvfile:
                writer = csv.DictWriter(csvfile, fieldnames=fieldnames)
                writer.writeheader()
            with open(self.val_csv, 'w') as csvfile:
                writer = csv.DictWriter(csvfile, fieldnames=fieldnames)
                writer.writeheader()
            with open(self.train_csv_intensity, 'w') as csvfile:
                writer = csv.DictWriter(csvfile, fieldnames=fieldnames)
                writer.writeheader()
            with open(self.val_csv_intensity, 'w') as csvfile:
                writer = csv.DictWriter(csvfile, fieldnames=fieldnames)
                writer.writeheader()
            print("=> finished creating source code backup.")

    def conditional_print(self,
                          split,
                          i,
                          epoch,
                          lr,
                          n_set,
                          blk_avg_meter,
                          avg_meter,
                          typeDI="depth"):
        if (i + 1) % self.args.print_freq == 0:
            avg = avg_meter.average()
            blk_avg = blk_avg_meter.average()
            print('=> output: {}'.format(self.output_directory))
            print(
                '{typeDI} {split} Epoch: {0} [{1}/{2}]\tlr={lr} '
                't_Data={blk_avg.data_time:.3f}({average.data_time:.3f}) '
                't_GPU={blk_avg.gpu_time:.3f}({average.gpu_time:.3f})\n\t'
                'RMSE={blk_avg.rmse:.2f}({average.rmse:.2f}) '
                'MAE={blk_avg.mae:.2f}({average.mae:.2f}) '
                'iRMSE={blk_avg.irmse:.2f}({average.irmse:.2f}) '
                'iMAE={blk_avg.imae:.2f}({average.imae:.2f})\n\t'
                'silog={blk_avg.silog:.2f}({average.silog:.2f}) '
                'squared_rel={blk_avg.squared_rel:.2f}({average.squared_rel:.2f}) '
                'Delta1={blk_avg.delta1:.3f}({average.delta1:.3f}) '
                'REL={blk_avg.absrel:.3f}({average.absrel:.3f})\n\t'
                'Lg10={blk_avg.lg10:.3f}({average.lg10:.3f}) '
                'Photometric={blk_avg.photometric:.3f}({average.photometric:.3f}) '
                .format(epoch,
                        i + 1,
                        n_set,
                        lr=lr,
                        blk_avg=blk_avg,
                        average=avg,
                        split=split.capitalize(),
                        typeDI=typeDI))
            blk_avg_meter.reset()

    def conditional_save_info(self, split, average_meter, epoch):
        avg = average_meter.average()
        if split == "train":
            csvfile_name = self.train_csv
        elif split == "val":
            csvfile_name = self.val_csv
        elif split == "eval":
            eval_filename = os.path.join(self.output_directory, 'eval.txt')
            self.save_single_txt(eval_filename, avg, epoch)
            return avg
        elif "test" in split:
            return avg
        else:
            raise ValueError("wrong split provided to logger")
        with open(csvfile_name, 'a') as csvfile:
            writer = csv.DictWriter(csvfile, fieldnames=fieldnames)
            writer.writerow({
                'epoch': epoch,
                'rmse': avg.rmse,
                'photo': avg.photometric,
                'mae': avg.mae,
                'irmse': avg.irmse,
                'imae': avg.imae,
                'mse': avg.mse,
                'silog': avg.silog,
                'squared_rel': avg.squared_rel,
                'absrel': avg.absrel,
                'lg10': avg.lg10,
                'delta1': avg.delta1,
                'delta2': avg.delta2,
                'delta3': avg.delta3,
                'gpu_time': avg.gpu_time,
                'data_time': avg.data_time
            })
        return avg

    def conditional_save_info_intensity(self, split, average_meter, epoch):
        avg = average_meter.average()
        if split == "train":
            csvfile_name = self.train_csv_intensity
        elif split == "val":
            csvfile_name = self.val_csv_intensity
        elif split == "eval":
            eval_filename = os.path.join(self.output_directory, 'eval.txt')
            self.save_single_txt(eval_filename, avg, epoch)
            return avg
        elif "test" in split:
            return avg
        else:
            raise ValueError("wrong split provided to logger")
        with open(csvfile_name, 'a') as csvfile:
            writer = csv.DictWriter(csvfile, fieldnames=fieldnames)
            writer.writerow({
                'epoch': epoch,
                'rmse': avg.rmse,
                'photo': avg.photometric,
                'mae': avg.mae,
                'irmse': avg.irmse,
                'imae': avg.imae,
                'mse': avg.mse,
                'silog': avg.silog,
                'squared_rel': avg.squared_rel,
                'absrel': avg.absrel,
                'lg10': avg.lg10,
                'delta1': avg.delta1,
                'delta2': avg.delta2,
                'delta3': avg.delta3,
                'gpu_time': avg.gpu_time,
                'data_time': avg.data_time
            })
        return avg

    def save_single_txt(self, filename, result, epoch):
        with open(filename, 'w') as txtfile:
            txtfile.write(
                ("rank_metric={}\n" + "epoch={}\n" + "rmse={:.3f}\n" +
                 "mae={:.3f}\n" + "silog={:.3f}\n" + "squared_rel={:.3f}\n" +
                 "irmse={:.3f}\n" + "imae={:.3f}\n" + "mse={:.3f}\n" +
                 "absrel={:.3f}\n" + "lg10={:.3f}\n" + "delta1={:.3f}\n" +
                 "t_gpu={:.4f}").format(self.args.rank_metric, epoch,
                                        result.rmse, result.mae, result.silog,
                                        result.squared_rel, result.irmse,
                                        result.imae, result.mse, result.absrel,
                                        result.lg10, result.delta1,
                                        result.gpu_time))

    def save_single_txt_with_intensity(self, filename, result,
                                       result_intensity, epoch):
        with open(filename, 'w') as txtfile:
            txtfile.write(
                ("rank_metric={}\n" + "epoch={}\n" + "depth_rmse={:.3f}\n" +
                 "mae={:.3f}\n" + "silog={:.3f}\n" + "squared_rel={:.3f}\n" +
                 "irmse={:.3f}\n" + "imae={:.3f}\n" + "mse={:.3f}\n" +
                 "absrel={:.3f}\n" + "lg10={:.3f}\n" + "delta1={:.3f}\n" +
                 "t_gpu={:.4f}\n" + "wi={}\n" + "btsize={}\n" +
                 "intensity_rmse={:.3f}\n" + "intensity_irmse={:.3f}\n" +
                 "intensity_mae={:.3f}\n" + "intensity_imae={:.3f}\n").format(
                     self.args.rank_metric, epoch, result.rmse, result.mae,
                     result.silog, result.squared_rel, result.irmse,
                     result.imae, result.mse, result.absrel, result.lg10,
                     result.delta1, result.gpu_time, self.args.wi,
                     self.args.batch_size, result_intensity.rmse,
                     result_intensity.irmse, result_intensity.mae,
                     result_intensity.imae))

    def save_best_txt(self, result, epoch):
        self.save_single_txt(self.best_txt, result, epoch)

    def save_best_txt_with_intensity(self, result, result_intensity, epoch):
        self.save_single_txt_with_intensity(self.best_txt, result,
                                            result_intensity, epoch)

    def _get_img_comparison_name(self, mode, epoch, is_best=False):
        if mode == 'eval':
            return self.output_directory + '/comparison_eval.png'
        if mode == 'val':
            if is_best:
                return self.output_directory + '/comparison_best.png'
            else:
                return self.output_directory + '/comparison_' + str(
                    epoch) + '.png'

    # ele: the input batch (batch_data dict)
    def conditional_save_img_comparison(self, mode, i, ele, pred, epoch):
        # save 8 images for visualization
        if mode == 'val' or mode == 'eval':
            skip = 100
            if i == 0:
                self.img_merge = vis_utils.merge_into_row(ele, pred)
            elif i % skip == 0 and i < 8 * skip:
                row = vis_utils.merge_into_row(ele, pred)
                self.img_merge = vis_utils.add_row(self.img_merge, row)
            elif i == 8 * skip:
                filename = self._get_img_comparison_name(mode, epoch)
                vis_utils.save_image(self.img_merge, filename)  #HWC
                # input C x H x W
                img_np_rescale = skimage.transform.rescale(np.array(
                    self.img_merge, dtype='float64'),
                                                           0.5,
                                                           order=0)
                img_np_rescale_CHW = np.transpose(img_np_rescale, (2, 0, 1))
                self.writer.add_image('eval_comparison', img_np_rescale_CHW, i)

    def conditional_save_img_comparison_with_intensity(self, mode, i, ele,
                                                       pred, pred_intensity,
                                                       epoch):
        # save 8 images for visualization
        if mode == 'val' or mode == 'eval':
            skip = 100
            if i == 0:
                self.img_merge = vis_utils.merge_into_row_with_intensity(
                    ele, pred, pred_intensity)
                # self.img_merge = vis_utils.merge_into_row(ele, pred)
            elif i % skip == 0 and i < 8 * skip:
                # row = vis_utils.merge_into_row(ele, pred)
                row = vis_utils.merge_into_row_with_intensity(
                    ele, pred, pred_intensity)
                self.img_merge = vis_utils.add_row(self.img_merge, row)
            elif i == 8 * skip:
                filename = self._get_img_comparison_name(mode, epoch)
                vis_utils.save_image(self.img_merge, filename)  #HWC
                # input C x H x W
                img_np_rescale = skimage.transform.rescale(np.array(
                    self.img_merge, dtype='float64'),
                                                           0.5,
                                                           order=0)
                img_np_rescale_CHW = np.transpose(img_np_rescale, (2, 0, 1))
                self.writer.add_image('comparison', img_np_rescale_CHW, i)

    def conditional_save_img_comparison_with_intensity2(
            self, mode, i, ele, pred, pred_pure, pred_intensity, epoch):
        # save 8 images for visualization
        if mode == 'val' or mode == 'eval':
            skip = 100
            if i == 0:
                self.img_merge = vis_utils.merge_into_row_with_intensity2(
                    ele, pred, pred_pure, pred_intensity)
                # self.img_merge = vis_utils.merge_into_row(ele, pred)
            elif i % skip == 0 and i < 8 * skip:
                # row = vis_utils.merge_into_row(ele, pred)
                row = vis_utils.merge_into_row_with_intensity2(
                    ele, pred, pred_pure, pred_intensity)
                self.img_merge = vis_utils.add_row(self.img_merge, row)
            elif i == 8 * skip:
                filename = self._get_img_comparison_name(mode, epoch)
                vis_utils.save_image(self.img_merge, filename)  #HWC
                # input C x H x W
                img_np_rescale = skimage.transform.rescale(np.array(
                    self.img_merge, dtype='float64'),
                                                           0.5,
                                                           order=0)
                img_np_rescale_CHW = np.transpose(img_np_rescale, (2, 0, 1))
                self.writer.add_image('comparison', img_np_rescale_CHW, i)

    def save_img_comparison_as_best(self, mode, epoch):
        if mode == 'val':
            filename = self._get_img_comparison_name(mode, epoch, is_best=True)
            vis_utils.save_image(self.img_merge, filename)
            self.writer.add_image('val_comparison_best',
                                  np.transpose(self.img_merge, (2, 0, 1)))

    def get_ranking_error(self, result):
        return getattr(result, self.args.rank_metric)

    def rank_conditional_save_best(self, mode, result, epoch):
        error = self.get_ranking_error(result)
        best_error = self.get_ranking_error(self.best_result)
        is_best = error < best_error
        if is_best and mode == "val":
            self.old_best_result = self.best_result
            self.best_result = result
            self.save_best_txt(result, epoch)
        return is_best

    def rank_conditional_save_best_with_intensity(self, mode, result,
                                                  result_intensity, epoch):
        error_depth = self.get_ranking_error(result)
        error_intensity = self.get_ranking_error(result_intensity)
        error = error_depth / 1000.0 + error_intensity / 5.0
        best_error_depth = self.get_ranking_error(self.best_result)
        best_error_intensity = self.get_ranking_error(
            self.best_result_intensity)
        best_error = best_error_depth / 1000.0 + best_error_intensity / 5.0
        is_best = error < best_error
        if is_best and mode == "val":
            self.old_best_result = self.best_result
            self.old_best_result_intensity = self.best_result_intensity
            self.best_result = result
            self.best_result_intensity = result_intensity
            self.save_best_txt_with_intensity(result, result_intensity, epoch)
            self.writer.add_scalar("eval/best_rmseD", result.rmse, epoch)
            self.writer.add_scalar("eval/best_rmseI", result_intensity.rmse,
                                   epoch)
            self.writer.add_scalar(
                "eval/best_rmseTotal",
                result.rmse / 1000.0 + result_intensity.rmse / 5.0, epoch)
        return is_best

    def conditional_save_pred_named_with_intensity(self, mode, name, pred,
                                                   pred_intensity, epoch):
        # if ("test" in mode or mode == "eval") and self.args.save_pred:
        if ("test" in mode or mode == "val") or mode == 'eval':

            # save images for visualization/ testing
            image_folder = os.path.join(self.output_directory,
                                        mode + "_output_depth")
            # print(name, image_folder, pred.shape)
            if not os.path.exists(image_folder):
                os.makedirs(image_folder)
            img = torch.squeeze(pred.data.cpu()).numpy()
            # filename = os.path.join(image_folder, '{0:010d}.png'.format(i))
            filename = os.path.join(image_folder, name)
            vis_utils.save_depth_as_uint16png(img, filename)

            image_folder = os.path.join(self.output_directory,
                                        mode + "_output_intensity")
            if not os.path.exists(image_folder):
                os.makedirs(image_folder)
            img = torch.squeeze(pred_intensity.data.cpu()).numpy()
            # filename = os.path.join(image_folder, '{0:010d}.png'.format(i))
            filename = os.path.join(image_folder, name)
            vis_utils.save_depth_as_uint16png(img, filename)

    def conditional_save_pred_named(self, mode, name, pred, epoch):
        # if ("test" in mode or mode == "eval") and self.args.save_pred:
        if ("test" in mode or mode == "val") or mode == 'eval':

            # save images for visualization/ testing
            image_folder = os.path.join(self.output_directory,
                                        mode + "_output")
            # print(name, image_folder)
            if not os.path.exists(image_folder):
                os.makedirs(image_folder)
            img = torch.squeeze(pred.data.cpu()).numpy()
            # filename = os.path.join(image_folder, '{0:010d}.png'.format(i))
            filename = os.path.join(image_folder, name)
            vis_utils.save_depth_as_uint16png(img, filename)

    def conditional_save_pred(self, mode, i, pred, epoch):
        if ("test" in mode or mode == "eval") and self.args.save_pred:

            # save images for visualization/ testing
            image_folder = os.path.join(self.output_directory,
                                        mode + "_output")
            if not os.path.exists(image_folder):
                os.makedirs(image_folder)
            img = torch.squeeze(pred.data.cpu()).numpy()
            filename = os.path.join(image_folder, '{0:010d}.png'.format(i))
            vis_utils.save_depth_as_uint16png(img, filename)

    def conditional_summarize(self, mode, avg, is_best):
        print("\n*\nSummary of ", mode, "round")
        print(''
              'RMSE={average.rmse:.3f}\n'
              'MAE={average.mae:.3f}\n'
              'Photo={average.photometric:.3f}\n'
              'iRMSE={average.irmse:.3f}\n'
              'iMAE={average.imae:.3f}\n'
              'squared_rel={average.squared_rel}\n'
              'silog={average.silog}\n'
              'Delta1={average.delta1:.3f}\n'
              'REL={average.absrel:.3f}\n'
              'Lg10={average.lg10:.3f}\n'
              't_GPU={time:.3f}'.format(average=avg, time=avg.gpu_time))
        if is_best and mode == "val":
            print("New best model by %s (was %.3f)" %
                  (self.args.rank_metric,
                   self.get_ranking_error(self.old_best_result)))
        elif mode == "val":
            print("(best %s is %.3f)" %
                  (self.args.rank_metric,
                   self.get_ranking_error(self.best_result)))
        print("*\n")

    def conditional_summarize_intensity(self, mode, avg):
        print("\n*\nIntensity Summary of ", mode, "round")
        print(''
              'RMSE={average.rmse:.3f}\n'
              'MAE={average.mae:.3f}\n'
              'Photo={average.photometric:.3f}\n'
              'iRMSE={average.irmse:.3f}\n'
              'iMAE={average.imae:.3f}\n'
              'squared_rel={average.squared_rel}\n'
              'silog={average.silog}\n'
              'Delta1={average.delta1:.3f}\n'
              'REL={average.absrel:.3f}\n'
              'Lg10={average.lg10:.3f}\n'
              't_GPU={time:.3f}'.format(average=avg, time=avg.gpu_time))
        print("*\n")
Example #4
File: main.py  Project: LeonSun0101/YT
def main():
    torch.cuda.set_device(config.cuda_id)
    global args, best_result, output_directory, train_csv, test_csv, batch_num, best_txt
    best_result = Result()
    best_result.set_to_worst()
    batch_num = 0
    output_directory = utils.get_output_directory(args)

    #-----------------#
    # pytorch version #
    #-----------------#

    try:
        torch._utils._rebuild_tensor_v2
    except AttributeError:

        def _rebuild_tensor_v2(storage, storage_offset, size, stride,
                               requires_grad, backward_hooks):
            tensor = torch._utils._rebuild_tensor(storage, storage_offset,
                                                  size, stride)
            tensor.requires_grad = requires_grad
            tensor._backward_hooks = backward_hooks
            return tensor

        torch._utils._rebuild_tensor_v2 = _rebuild_tensor_v2

    if not os.path.exists(output_directory):
        os.makedirs(output_directory)

    train_csv = os.path.join(output_directory, 'train.csv')
    test_csv = os.path.join(output_directory, 'test.csv')
    best_txt = os.path.join(output_directory, 'best.txt')

    nowTime = datetime.datetime.now().strftime('%Y-%m-%d %H:%M:%S')
    file = open(namefile, 'a+')
    file.writelines(
        str("====================================================") +
        str(nowTime) + '\n')
    file.writelines(str("Cuda_id: ") + str(config.cuda_id) + '\n')
    file.writelines(str("NAME: ") + str(config.name) + '\n')
    file.writelines(str("Description: ") + str(config.description) + '\n')
    file.writelines(
        str("model: ") + str(args.arch) + '\n' + str("loss_final: ") +
        str(args.criterion) + '\n' + str("loss_1: ") + str(config.LOSS_1) +
        '\n' + str("batch_size:") + str(args.batch_size) + '\n')
    file.writelines(str("zoom_scale: ") + str(config.zoom_scale) + '\n')
    file.writelines(str("------------------------") + '\n')
    file.writelines(str("Train_dataste: ") + str(config.train_dir) + '\n')
    file.writelines(str("Validation_dataste: ") + str(config.val_dir) + '\n')
    file.writelines(str("------------------------") + '\n')
    file.writelines(str("Input_type: ") + str(config.input) + '\n')
    file.writelines(str("target_type: ") + str(config.target) + '\n')
    file.writelines(str("LOSS--------------------") + '\n')
    file.writelines(str("Loss_num: ") + str(config.loss_num) + '\n')
    file.writelines(
        str("loss_final: ") + str(args.criterion) + '\n' + str("loss_1: ") +
        str(config.LOSS_1) + '\n')
    file.writelines(
        str("loss_0_weight: ") + str(config.LOSS_0_weight) + '\n' +
        str("loss_1_weight: ") + str(config.LOSS_1_weight) + '\n')
    file.writelines(
        str("weight_GT_canny: ") + str(config.weight_GT_canny_loss) + '\n' +
        str("weight_GT_sobel: ") + str(config.weight_GT_sobel_loss) + '\n' +
        str("weight_rgb_sobel: ") + str(config.weight_rgb_sobel_loss) + '\n')
    file.writelines(str("------------------------") + '\n')
    file.writelines(str("target: ") + str(config.target) + '\n')
    file.writelines(str("data_loader_type: ") + str(config.data_loader) + '\n')
    file.writelines(str("lr: ") + str(config.Init_lr) + '\n')
    file.writelines(str("save_fc: ") + str(config.save_fc) + '\n')
    file.writelines(str("Max epoch: ") + str(config.epoch) + '\n')
    file.close()

    # define the loss function (criterion) and optimizer
    if args.criterion == 'l2':
        criterion = criteria.MaskedMSELoss().cuda()
    elif args.criterion == 'l1':
        criterion = criteria.MaskedL1Loss().cuda()
    elif args.criterion == 'l1_canny':
        criterion = criteria.MaskedL1_cannyLoss().cuda()
    #SOBEL
    elif args.criterion == 'l1_from_rgb_sobel':
        criterion = criteria.MaskedL1_from_rgb_sobel_Loss().cuda()
    elif args.criterion == 'l1_from_GT_rgb_sobel':
        criterion = criteria.MaskedL1_from_GT_rgb_sobel_Loss().cuda()
    elif args.criterion == 'l1_from_GT_sobel':
        criterion = criteria.MaskedL1_from_GT_sobel_Loss().cuda()
    elif args.criterion == 'l2_from_GT_sobel_Loss':
        criterion = criteria.MaskedL2_from_GT_sobel_Loss().cuda()
    #CANNY
    elif args.criterion == 'l1_canny_from_GT_canny':
        criterion = criteria.MaskedL1_canny_from_GT_Loss().cuda()
    else:
        raise ValueError("unsupported criterion: {}".format(args.criterion))

    # Data loading code
    print("=> creating data loaders ...")
    train_dir = config.train_dir
    val_dir = config.val_dir
    train_dataset = YT_dataset(train_dir, config, is_train_set=True)
    val_dataset = YT_dataset(val_dir, config, is_train_set=False)

    train_loader = torch.utils.data.DataLoader(
        train_dataset,
        batch_size=config.batch_size,
        shuffle=True,
        num_workers=args.workers,
        pin_memory=True,
        sampler=None,
        worker_init_fn=lambda work_id: np.random.seed(work_id))
    # worker_init_fn ensures different sampling patterns for each data loading thread
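    # Note: np.random.seed(work_id) depends only on the worker id, so NumPy-based
    # augmentations replay the same random sequence every epoch (workers are
    # re-created with the same ids each epoch). Seeding with work_id plus the
    # epoch, or with torch.initial_seed() % 2**32, would vary it per epoch.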

    # set batch size to be 1 for validation
    val_loader = torch.utils.data.DataLoader(val_dataset,
                                             batch_size=1,
                                             shuffle=False,
                                             num_workers=args.workers,
                                             pin_memory=True)
    print("=> data loaders created.")

    if args.evaluate:
        best_model_filename = os.path.join(output_directory,
                                           'model_best.pth.tar')
        assert os.path.isfile(best_model_filename), \
        "=> no best model found at '{}'".format(best_model_filename)
        print("=> loading best model '{}'".format(best_model_filename))
        checkpoint = torch.load(best_model_filename)
        args.start_epoch = checkpoint['epoch']
        best_result = checkpoint['best_result']
        model = checkpoint['model']
        print("=> loaded best model (epoch {})".format(checkpoint['epoch']))
        validate(val_loader,
                 model,
                 checkpoint['epoch'],
                 1,
                 write_to_file=False)
        return

    elif args.test:
        print("testing...")
        best_model_filename = best_model_dir
        assert os.path.isfile(best_model_filename), \
            "=> no best model found at '{}'".format(best_model_filename)
        print("=> loading best model '{}'".format(best_model_filename))
        checkpoint = torch.load(best_model_filename)
        args.start_epoch = checkpoint['epoch']
        best_result = checkpoint['best_result']
        model = checkpoint['model']
        print("=> loaded best model (epoch {})".format(checkpoint['epoch']))
        optimizer = checkpoint['optimizer']
        for state in optimizer.state.values():
            for k, v in state.items():
                print(type(v))
                if torch.is_tensor(v):
                    state[k] = v.cuda()

        #test(val_loader, model, checkpoint['epoch'], write_to_file=False)
        test(model)
        return

    elif args.resume:
        assert os.path.isfile(config.resume_model_dir), \
            "=> no checkpoint found at '{}'".format(config.resume_model_dir)
        print("=> loading checkpoint '{}'".format(config.resume_model_dir))
        best_model_filename = config.resume_model_dir
        checkpoint = torch.load(best_model_filename)
        args.start_epoch = checkpoint['epoch'] + 1
        best_result = checkpoint['best_result']
        model = checkpoint['model']
        optimizer = checkpoint['optimizer']
        for state in optimizer.state.values():
            for k, v in state.items():
                #print(type(v))
                if torch.is_tensor(v):
                    state[k] = v.cuda(config.cuda_id)

        print("=> loaded checkpoint (epoch {})".format(checkpoint['epoch']))

    else:
        print("=> creating Model ({}-{}) ...".format(args.arch, args.decoder))
        if config.input == 'RGBT':
            in_channels = 4
        elif config.input == 'YT':
            in_channels = 2
        else:
            print("Input type is wrong !")
            return 0
        if args.arch == 'resnet50':  # instantiate the model from the ResNet definition; in_channels is set by the input type above
            model = ResNet(layers=50,
                           decoder=args.decoder,
                           output_size=train_dataset.output_size,
                           in_channels=in_channels,
                           pretrained=args.pretrained)
        elif args.arch == 'resnet50_deconv1_loss0':
            model = ResNet_with_deconv(layers=50,
                                       decoder=args.decoder,
                                       output_size=train_dataset.output_size,
                                       in_channels=in_channels,
                                       pretrained=args.pretrained)
        elif args.arch == 'resnet50_deconv1_loss1':
            model = ResNet_with_deconv_loss(
                layers=50,
                decoder=args.decoder,
                output_size=train_dataset.output_size,
                in_channels=in_channels,
                pretrained=args.pretrained)
        elif args.arch == 'resnet50_direct_deconv1_loss1':
            model = ResNet_with_direct_deconv(
                layers=50,
                decoder=args.decoder,
                output_size=train_dataset.output_size,
                in_channels=in_channels,
                pretrained=args.pretrained)
        elif args.arch == 'resnet50_1':
            model = ResNet_1(layers=50,
                             decoder=args.decoder,
                             output_size=train_dataset.output_size,
                             in_channels=in_channels,
                             pretrained=args.pretrained)
        elif args.arch == 'resnet50_2':
            model = ResNet_2(layers=50,
                             decoder=args.decoder,
                             output_size=train_dataset.output_size,
                             in_channels=in_channels,
                             pretrained=args.pretrained)
        elif args.arch == 'resnet50_3':
            model = ResNet_3(layers=50,
                             decoder=args.decoder,
                             output_size=train_dataset.output_size,
                             in_channels=in_channels,
                             pretrained=args.pretrained)
        elif args.arch == 'resnet50_3_1':
            model = ResNet_3_1(layers=50,
                               decoder=args.decoder,
                               output_size=train_dataset.output_size,
                               in_channels=in_channels,
                               pretrained=args.pretrained)
        elif args.arch == 'resnet50_3_2':
            model = ResNet_3_2(layers=50,
                               decoder=args.decoder,
                               output_size=train_dataset.output_size,
                               in_channels=in_channels,
                               pretrained=args.pretrained)
        elif args.arch == 'resnet50_3_3':
            model = ResNet_3_3(layers=50,
                               decoder=args.decoder,
                               output_size=train_dataset.output_size,
                               in_channels=in_channels,
                               pretrained=args.pretrained)
        elif args.arch == 'resnet50_4':
            model = ResNet_4(layers=50,
                             decoder=args.decoder,
                             output_size=train_dataset.output_size,
                             in_channels=in_channels,
                             pretrained=args.pretrained)
        elif args.arch == 'resnet50_5':
            model = ResNet_5(layers=50,
                             decoder=args.decoder,
                             output_size=train_dataset.output_size,
                             in_channels=in_channels,
                             pretrained=args.pretrained)
        elif args.arch == 'resnet50_7':
            model = ResNet_7(layers=50,
                             decoder=args.decoder,
                             output_size=train_dataset.output_size,
                             in_channels=in_channels,
                             pretrained=args.pretrained)
        elif args.arch == 'resnet50_8':
            model = ResNet_8(layers=50,
                             decoder=args.decoder,
                             output_size=train_dataset.output_size,
                             in_channels=in_channels,
                             pretrained=args.pretrained)
        elif args.arch == 'resnet50_9':
            model = ResNet_9(layers=50,
                             decoder=args.decoder,
                             output_size=train_dataset.output_size,
                             in_channels=in_channels,
                             pretrained=args.pretrained)
        elif args.arch == 'resnet50_10':
            model = ResNet_10(layers=50,
                              decoder=args.decoder,
                              output_size=train_dataset.output_size,
                              in_channels=in_channels,
                              pretrained=args.pretrained)
        elif args.arch == 'resnet50_11':
            model = ResNet_11(layers=50,
                              decoder=args.decoder,
                              output_size=train_dataset.output_size,
                              in_channels=in_channels,
                              pretrained=args.pretrained)
        elif args.arch == 'resnet50_11_1':
            model = ResNet_11_1(layers=50,
                                decoder=args.decoder,
                                output_size=train_dataset.output_size,
                                in_channels=in_channels,
                                pretrained=args.pretrained)
        elif args.arch == 'resnet50_11_without_pretrain':
            model = ResNet_11_without_pretrain(
                layers=50,
                decoder=args.decoder,
                output_size=train_dataset.output_size,
                in_channels=in_channels,
                pretrained=args.pretrained)
        elif args.arch == 'resnet50_12':
            model = ResNet_12(layers=50,
                              decoder=args.decoder,
                              output_size=train_dataset.output_size,
                              in_channels=in_channels,
                              pretrained=args.pretrained)
        elif args.arch == 'resnet50_13':
            model = ResNet_13(layers=50,
                              decoder=args.decoder,
                              output_size=train_dataset.output_size,
                              in_channels=in_channels,
                              pretrained=args.pretrained)
        elif args.arch == 'resnet50_14':
            model = ResNet_14(layers=50,
                              decoder=args.decoder,
                              output_size=train_dataset.output_size,
                              in_channels=in_channels,
                              pretrained=args.pretrained)
        elif args.arch == 'resnet50_15':
            model = ResNet_15(layers=50,
                              decoder=args.decoder,
                              output_size=train_dataset.output_size,
                              in_channels=in_channels,
                              pretrained=args.pretrained)
        elif args.arch == 'resnet50_16':
            model = ResNet_16(layers=50,
                              decoder=args.decoder,
                              output_size=train_dataset.output_size,
                              in_channels=in_channels,
                              pretrained=args.pretrained)
        elif args.arch == 'resnet50_17':
            model = ResNet_17(layers=50,
                              decoder=args.decoder,
                              output_size=train_dataset.output_size,
                              in_channels=in_channels,
                              pretrained=args.pretrained)
        elif args.arch == 'resnet50_18':
            model = ResNet50_18(layers=50,
                                decoder=args.decoder,
                                output_size=train_dataset.output_size,
                                in_channels=in_channels,
                                pretrained=args.pretrained)
        elif args.arch == 'resnet50_30':
            model = ResNet_30(layers=50,
                              decoder=args.decoder,
                              output_size=train_dataset.output_size,
                              in_channels=in_channels,
                              pretrained=args.pretrained)
        elif args.arch == 'resnet50_31':
            model = ResNet_31(layers=50,
                              decoder=args.decoder,
                              output_size=train_dataset.output_size,
                              in_channels=in_channels,
                              pretrained=args.pretrained)
        elif args.arch == 'resnet50_32':
            model = ResNet_32(layers=50,
                              decoder=args.decoder,
                              output_size=train_dataset.output_size,
                              in_channels=in_channels,
                              pretrained=args.pretrained)
        elif args.arch == 'resnet50_33':
            model = ResNet_33(layers=50,
                              decoder=args.decoder,
                              output_size=train_dataset.output_size,
                              in_channels=in_channels,
                              pretrained=args.pretrained)
        elif args.arch == 'resnet50_40':
            model = ResNet_40(layers=50,
                              decoder=args.decoder,
                              output_size=train_dataset.output_size,
                              in_channels=in_channels,
                              pretrained=args.pretrained)
        elif args.arch == 'resnet50_15_1':
            model = ResNet_15_1(layers=50,
                                decoder=args.decoder,
                                output_size=train_dataset.output_size,
                                in_channels=in_channels,
                                pretrained=args.pretrained)
        elif args.arch == 'resnet50_15_2':
            model = ResNet_15_2(layers=50,
                                decoder=args.decoder,
                                output_size=train_dataset.output_size,
                                in_channels=in_channels,
                                pretrained=args.pretrained)
        elif args.arch == 'resnet50_15_3':
            model = ResNet_15_3(layers=50,
                                decoder=args.decoder,
                                output_size=train_dataset.output_size,
                                in_channels=in_channels,
                                pretrained=args.pretrained)
        elif args.arch == 'resnet50_15_4':
            model = ResNet_15_4(layers=50,
                                decoder=args.decoder,
                                output_size=train_dataset.output_size,
                                in_channels=in_channels,
                                pretrained=args.pretrained)
        elif args.arch == 'resnet50_15_5':
            model = ResNet_15_5(layers=50,
                                decoder=args.decoder,
                                output_size=train_dataset.output_size,
                                in_channels=in_channels,
                                pretrained=args.pretrained)
        elif args.arch == 'resnet50_15_6':
            model = ResNet_15_6(layers=50,
                                decoder=args.decoder,
                                output_size=train_dataset.output_size,
                                in_channels=in_channels,
                                pretrained=args.pretrained)
        elif args.arch == 'resnet50_15_8':
            model = ResNet_15_8(layers=34,
                                decoder=args.decoder,
                                output_size=train_dataset.output_size,
                                in_channels=in_channels,
                                pretrained=args.pretrained)
        elif args.arch == 'resnet50_15_9':
            model = ResNet_15_9(layers=50,
                                decoder=args.decoder,
                                output_size=train_dataset.output_size,
                                in_channels=in_channels,
                                pretrained=args.pretrained)
        elif args.arch == 'resnet50_15_10':
            model = ResNet_15_10(layers=50,
                                 decoder=args.decoder,
                                 output_size=train_dataset.output_size,
                                 in_channels=in_channels,
                                 pretrained=args.pretrained)
        elif args.arch == 'resnet50_15_11':
            model = ResNet_15_11(layers=50,
                                 decoder=args.decoder,
                                 output_size=train_dataset.output_size,
                                 in_channels=in_channels,
                                 pretrained=args.pretrained)
        elif args.arch == 'resnet50_15_12':
            model = ResNet_15_12(layers=50,
                                 decoder=args.decoder,
                                 output_size=train_dataset.output_size,
                                 in_channels=in_channels,
                                 pretrained=args.pretrained)
        elif args.arch == 'resnet18':
            model = ResNet(layers=18,
                           decoder=args.decoder,
                           output_size=train_dataset.output_size,
                           in_channels=in_channels,
                           pretrained=args.pretrained)
        elif args.arch == 'resnet50_20':
            model = ResNet50_20(Bottleneck, [3, 4, 6, 3])
        elif args.arch == 'UNet':
            model = UNet()
        elif args.arch == 'UP_only':
            model = UP_only()
        elif args.arch == 'ResNet_bicubic':
            model = ResNet_bicubic(layers=50,
                                   decoder=args.decoder,
                                   output_size=train_dataset.output_size,
                                   in_channels=in_channels,
                                   pretrained=args.pretrained)
        elif args.arch == 'VDSR':
            model = VDSR()
        elif args.arch == 'VDSR_without_res':
            model = VDSR_without_res()
        elif args.arch == 'VDSR_16':
            model = VDSR_16()
        elif args.arch == 'VDSR_16_2':
            model = VDSR_16_2()
        elif args.arch == 'Leon_resnet50':
            model = Leon_resnet50()
        elif args.arch == 'Leon_resnet101':
            model = Leon_resnet101()
        elif args.arch == 'Leon_resnet18':
            model = Leon_resnet18()
        elif args.arch == 'Double_resnet50':
            model = Double_resnet50()
        else:
            raise ValueError("unsupported architecture: {}".format(args.arch))
        print("=> model created.")

        if args.finetune:
            print("===============loading finetune model=====================")
            assert os.path.isfile(config.fitune_model_dir), \
            "=> no checkpoint found at '{}'".format(config.fitune_model_dir)
            print("=> loading checkpoint '{}'".format(config.fitune_model_dir))
            best_model_filename = config.fitune_model_dir
            checkpoint = torch.load(best_model_filename)
            args.start_epoch = checkpoint['epoch'] + 1
            #best_result = checkpoint['best_result']
            model_fitune = checkpoint['model']
            model_fitune_dict = model_fitune.state_dict()
            model_dict = model.state_dict()
            for k in model_fitune_dict:
                if k in model_dict:
                    #print("There is model k: ",k)
                    model_dict[k] = model_fitune_dict[k]
            #model_dict={k:v for k,v in model_fitune_dict.items() if k in model_dict}
            model_dict.update(model_fitune_dict)
            model.load_state_dict(model_dict)

            #optimizer = checkpoint['optimizer']
            print("=> loaded checkpoint (epoch {})".format(
                checkpoint['epoch']))

        #optimizer = torch.optim.SGD(model.parameters(), args.lr,momentum=args.momentum, weight_decay=args.weight_decay)
        optimizer = torch.optim.Adam(model.parameters(),
                                     lr=args.lr,
                                     amsgrad=True,
                                     weight_decay=args.weight_decay)
        '''
        optimizer = torch.optim.Adam(
        [
            #{'params':model.base.parameters()}, 3
            {'params': model.re_conv_Y_1.parameters(),'lr':0.0001},
            {'params': model.re_conv_Y_2.parameters(), 'lr': 0.0001},
            {'params': model.re_conv_Y_3.parameters(), 'lr': 0.0001},
            #3
            {'params': model.re_deconv_up0.parameters(), 'lr': 0.0001},
            {'params': model.re_deconv_up1.parameters(), 'lr': 0.0001},
            {'params': model.re_deconv_up2.parameters(), 'lr': 0.0001},
            #3
            {'params': model.re_conv1.parameters(), 'lr': 0.0001},
            {'params': model.re_bn1.parameters(), 'lr': 0.0001},
            {'params': model.re_conv4.parameters(), 'lr': 0.0001},
            #5
            {'params': model.re_ResNet50_layer1.parameters(), 'lr': 0.0001},
            {'params': model.re_ResNet50_layer2.parameters(), 'lr': 0.0001},
            {'params': model.re_ResNet50_layer3.parameters(), 'lr': 0.0001},
            {'params': model.re_ResNet50_layer4.parameters(), 'lr': 0.0001},

            {'params': model.re_bn2.parameters(), 'lr': 0.0001},
            #5
            {'params': model.re_deconcv_res_up1.parameters(), 'lr': 0.0001},
            {'params': model.re_deconcv_res_up2.parameters(), 'lr': 0.0001},
            {'params': model.re_deconcv_res_up3.parameters(), 'lr': 0.0001},
            {'params': model.re_deconcv_res_up4.parameters(), 'lr': 0.0001},

            {'params': model.re_deconv_last.parameters(), 'lr': 0.0001},
            #denoise net 3
            {'params': model.conv_denoise_1.parameters(), 'lr': 0},
            {'params': model.conv_denoise_2.parameters(), 'lr': 0},
            {'params': model.conv_denoise_3.parameters(), 'lr': 0}
        ]
        , lr=args.lr, amsgrad=True, weight_decay=args.weight_decay)
        '''
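        # The commented-out optimizer above illustrates PyTorch's per-parameter-group
        # API: each dict supplies its own 'params' iterable and 'lr', so individual
        # sub-networks can train at different rates (lr=0 effectively freezes the
        # denoise layers).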
        # move any existing optimizer state tensors onto the selected GPU
        for state in optimizer.state.values():
            for k, v in state.items():
                if torch.is_tensor(v):
                    state[k] = v.cuda(config.cuda_id)
        print(optimizer)
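        # If the optimizer were restored from a checkpoint saved on the CPU, its state
        # tensors (e.g. Adam's exp_avg buffers) would need to live on the same device
        # as the model; for a freshly constructed optimizer the state is still empty,
        # so the loop above is a no-op until the first update.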

        # create new csv files with only header
        with open(train_csv, 'w') as csvfile:
            writer = csv.DictWriter(csvfile, fieldnames=fieldnames)
            writer.writeheader()
        with open(test_csv, 'w') as csvfile:
            writer = csv.DictWriter(csvfile, fieldnames=fieldnames)
            writer.writeheader()


#    writer = SummaryWriter(log_dir='logs')

    model = model.cuda(config.cuda_id)
    #torch.save(model, './net1.pkl')
    # keep optimizer state on the same GPU as the model
    for state in optimizer.state.values():
        for k, v in state.items():
            if torch.is_tensor(v):
                state[k] = v.cuda(config.cuda_id)

    print("=> model transferred to GPU.")

    for epoch in range(args.start_epoch, args.epochs):
        train(train_loader, val_loader, model, criterion, optimizer, epoch,
              args.lr)  # train for one epoch
Example #5
def main() -> int:
    best_result = Result()
    best_result.set_to_worst()
    args: Any
    args = parser.parse_args()
    dataset = args.data
    if args.modality == 'rgb' and args.num_samples != 0:
        print("number of samples is forced to be 0 when input modality is rgb")
        args.num_samples = 0
    image_shape = (192, 256)  # if "my" in args.arch else (228, 304)

    # create results folder, if not already exists
    if args.transfer_from:
        output_directory = f"{args.transfer_from}_transfer"
    else:
        output_directory = utils.get_output_dir(args)
    args.data = os.path.join(os.environ["DATASET_DIR"], args.data)
    print("output directory :", output_directory)
    if not os.path.exists(output_directory):
        os.makedirs(output_directory)
    elif not args.evaluate:
        raise Exception("output directory allready exists")

    train_csv = os.path.join(output_directory, 'train.csv')
    test_csv = os.path.join(output_directory, 'test.csv')
    best_txt = os.path.join(output_directory, 'best.txt')

    # define loss function (criterion) and optimizer
    if args.criterion == 'l2':
        criterion = criteria.MaskedMSELoss().cuda()
    elif args.criterion == 'l1':
        criterion = criteria.MaskedL1Loss().cuda()
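    # Both criteria are assumed to be masked losses, i.e. the error is averaged only
    # over pixels that carry valid ground-truth depth, which is the usual convention
    # for sparse depth supervision.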
    out_channels = 1
    # Data loading code
    print("=> creating data loaders ...")
    traindir = os.path.join(args.data, 'train')
    valdir = traindir if dataset == "SUNRGBD" else os.path.join(
        args.data, 'val')
    DatasetType = choose_dataset_type(dataset)
    train_dataset = DatasetType(traindir,
                                phase='train',
                                modality=args.modality,
                                num_samples=args.num_samples,
                                square_width=args.square_width,
                                output_shape=image_shape,
                                depth_type=args.depth_type)

    train_loader = torch.utils.data.DataLoader(train_dataset,
                                               batch_size=args.batch_size,
                                               shuffle=True,
                                               num_workers=args.workers,
                                               pin_memory=True,
                                               sampler=None)

    print("=> training examples:", len(train_dataset))

    val_dataset = DatasetType(valdir,
                              phase='val',
                              modality=args.modality,
                              num_samples=args.num_samples,
                              square_width=args.square_width,
                              output_shape=image_shape,
                              depth_type=args.depth_type)

    val_loader = torch.utils.data.DataLoader(val_dataset,
                                             batch_size=1,
                                             shuffle=False,
                                             num_workers=args.workers,
                                             pin_memory=True)
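    # validation runs with batch_size=1 so each image is evaluated individually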

    print("=> validation examples:", len(val_dataset))

    print("=> data loaders created.")

    # evaluation mode
    if args.evaluate:
        best_model_filename = os.path.join(output_directory,
                                           'model_best.pth.tar')
        if os.path.isfile(best_model_filename):
            print("=> loading best model '{}'".format(best_model_filename))
            checkpoint = torch.load(best_model_filename)
            args.start_epoch = checkpoint['epoch']
            best_result = checkpoint['best_result']
            model = checkpoint['model']
            print("=> loaded best model (epoch {})".format(
                checkpoint['epoch']))
        else:
            print("=> no best model found at '{}'".format(best_model_filename))
            return 1
        avg_result, avg_result_inside, avg_result_outside, _, results, evaluator = validate(
            val_loader,
            args.square_width,
            args.modality,
            output_directory,
            args.print_freq,
            test_csv,
            model,
            checkpoint['epoch'],
            write_to_file=False)
        write_results(best_txt, avg_result, avg_result_inside,
                      avg_result_outside, checkpoint['epoch'])
        for loss_name, losses in [
            ("rmses", (res.result.rmse for res in results)),
            ("delta1s", (res.result.delta1 for res in results)),
            ("delta2s", (res.result.delta2 for res in results)),
            ("delta3s", (res.result.delta3 for res in results)),
            ("maes", (res.result.mae for res in results)),
            ("absrels", (res.result.absrel for res in results)),
            ("rmses_inside", (res.result_inside.rmse for res in results)),
            ("delta1s_inside", (res.result_inside.delta1 for res in results)),
            ("delta2s_inside", (res.result_inside.delta2 for res in results)),
            ("delta3s_inside", (res.result_inside.delta3 for res in results)),
            ("maes_inside", (res.result_inside.mae for res in results)),
            ("absrels_inside", (res.result_inside.absrel for res in results)),
            ("rmses_outside", (res.result_outside.rmse for res in results)),
            ("delta1s_outside", (res.result_outside.delta1
                                 for res in results)),
            ("delta2s_outside", (res.result_outside.delta2
                                 for res in results)),
            ("delta3s_outside", (res.result_outside.delta3
                                 for res in results)),
            ("maes_outside", (res.result_outside.mae for res in results)),
            ("absrels_outside", (res.result_outside.absrel
                                 for res in results)),
        ]:
            with open(
                    os.path.join(output_directory,
                                 f"validation_{loss_name}.csv"),
                    "w") as csv_file:
                wr = csv.writer(csv_file, quoting=csv.QUOTE_ALL)
                wr.writerow(losses)
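        # Each file ends up with a single row holding the chosen metric for every
        # validation image; writerow() consumes the corresponding generator defined
        # in the list above.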

        evaluator.save_plot(os.path.join(output_directory, "best.png"))
        return 0

    # optionally resume from a checkpoint
    elif args.resume:
        if os.path.isfile(args.resume):
            print("=> loading checkpoint '{}'".format(args.resume))
            checkpoint = torch.load(args.resume)
            args.start_epoch = checkpoint['epoch'] + 1
            best_result = checkpoint['best_result']
            model = checkpoint['model']
            optimizer = checkpoint['optimizer']
            # the LR scheduler is not stored in the checkpoint, so plateau-based
            # adjustment stays disabled when resuming
            adjusting_learning_rate = False
            print("=> loaded checkpoint (epoch {})".format(
                checkpoint['epoch']))
        else:
            print("=> no checkpoint found at '{}'".format(args.resume))
            return 1
    # create new model
    else:
        if args.transfer_from:
            if os.path.isfile(args.transfer_from):
                print(f"=> loading checkpoint '{args.transfer_from}'")
                checkpoint = torch.load(args.transfer_from)
                args.start_epoch = 0
                model = checkpoint['model']
                print("=> loaded checkpoint")
                if args.train_top_only:
                    train_params = list(model.conv3.parameters()) + list(
                        model.decoder.layer4.parameters())
                else:
                    train_params = model.parameters()
            else:
                print(f"=> no checkpoint found at '{args.transfer_from}'")
                return 1
        else:
            # define model
            print("=> creating Model ({}-{}) ...".format(
                args.arch, args.decoder))
            in_channels = len(args.modality)
            if args.arch == 'resnet50':
                n_layers = 50
            elif args.arch == 'resnet18':
                n_layers = 18
            model = ResNet(layers=n_layers,
                           decoder=args.decoder,
                           in_channels=in_channels,
                           out_channels=out_channels,
                           pretrained=args.pretrained,
                           image_shape=image_shape,
                           skip_type=args.skip_type)
            print("=> model created.")
            train_params = model.parameters()

        adjusting_learning_rate = False
        if args.optimizer == "sgd":
            optimizer = torch.optim.SGD(train_params,
                                        args.lr,
                                        momentum=args.momentum,
                                        weight_decay=args.weight_decay)
            adjusting_learning_rate = True
        elif args.optimizer == "adam":
            optimizer = torch.optim.Adam(train_params,
                                         weight_decay=args.weight_decay)
        else:
            raise Exception("We should never be here")

        if adjusting_learning_rate:
            print("=> Learning rate adjustment enabled.")
            scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
                optimizer, patience=args.adjust_lr_ep, verbose=True)
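            # ReduceLROnPlateau lowers the learning rate when the monitored value stops
            # improving; scheduler.step(res_val.rmse) is called once per epoch at the
            # bottom of the training loop below.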
        # create new csv files with only header
        with open(train_csv, 'w') as csvfile:
            writer = csv.DictWriter(csvfile, fieldnames=fieldnames)
            writer.writeheader()
        with open(test_csv, 'w') as csvfile:
            writer = csv.DictWriter(csvfile, fieldnames=fieldnames)
            writer.writeheader()

    # model = torch.nn.DataParallel(model).cuda()
    model = model.cuda()
    print(model)
    print("=> model transferred to GPU.")
    epochs_since_best = 0
    train_results = []
    val_results = []
    for epoch in range(args.start_epoch, args.epochs):
        # train for one epoch
        res_train, res_train_inside, res_train_outside = train(
            train_loader, model, criterion, optimizer, epoch, args.print_freq,
            train_csv)
        train_results.append((res_train, res_train_inside, res_train_outside))
        # evaluate on validation set
        res_val, res_val_inside, res_val_outside, img_merge, _, _ = validate(
            val_loader, args.square_width, args.modality, output_directory,
            args.print_freq, test_csv, model, epoch, True)
        val_results.append((res_val, res_val_inside, res_val_outside))
        # remember best rmse and save checkpoint
        is_best = res_val.rmse < best_result.rmse
        if is_best:
            epochs_since_best = 0
            best_result = res_val
            write_results(best_txt, res_val, res_val_inside, res_val_outside,
                          epoch)
            if img_merge is not None:
                img_filename = output_directory + '/comparison_best.png'
                utils.save_image(img_merge, img_filename)
        else:
            epochs_since_best += 1

        save_checkpoint(
            {
                'epoch': epoch,
                'arch': args.arch,
                'model': model,
                'best_result': best_result,
                'optimizer': optimizer,
            }, is_best, epoch, output_directory)
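        # The checkpoint stores the full model and optimizer objects (not just their
        # state_dicts), which is what lets the evaluate/resume branches above reload
        # them directly, at the cost of tying checkpoints to these exact class
        # definitions.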

        plot_progress(train_results, val_results, epoch, output_directory)

        if epochs_since_best > args.early_stop_epochs:
            print("early stopping")
            break
        if adjusting_learning_rate:
            scheduler.step(res_val.rmse)
    return 0