Esempio n. 1
0
    def __init__(self,
                 n_classes=10,
                 latent_sz=32,
                 ngf=32,
                 init_type='orthogonal',
                 init_gain=0.1,
                 img_sz=32):
        """Build the CGN heads (shape, two textures) and the patch shuffler."""
        super(CGN, self).__init__()

        # By default a single image is sampled at a time.
        self.batch_size = 1
        self.n_classes = n_classes
        self.latent_sz = latent_sz
        self.label_emb = nn.Embedding(n_classes, n_classes)

        # Spatial size the generators start from, and their input width
        # (latent vector concatenated with the one-hot-sized label embedding).
        start_sz = img_sz // 4
        in_dim = self.latent_sz + self.n_classes

        # Sub-networks: one 1-channel shape head and two tanh-bounded
        # 3-channel texture heads, plus a tile-and-crop shuffler.
        self.f_shape = nn.Sequential(*shape_layers(in_dim, 1, ngf, start_sz))
        self.f_text1 = nn.Sequential(*texture_layers(in_dim, 3, ngf, start_sz),
                                     nn.Tanh())
        self.f_text2 = nn.Sequential(*texture_layers(in_dim, 3, ngf, start_sz),
                                     nn.Tanh())
        self.shuffler = nn.Sequential(Patch2Image(img_sz, 2),
                                      RandomCrop(img_sz))

        init_net(self, init_type=init_type, init_gain=init_gain)
Esempio n. 2
0
def build_basic_network(in_channels, in_size, out_dim, noisy, sigma0, net_file):
    """Assemble a BasicNetwork: default conv stack plus a fully-connected head.

    The FC head is built with noisy layers (parameter sigma0) when `noisy` is
    true. Weights are loaded/initialized from `net_file` via utils.init_net.
    """
    conv = _build_default_conv(in_channels)

    # Probe the conv stack with a dummy batch shape to size the first FC layer.
    probe_shape = (1, in_channels, in_size, in_size)
    fc_dims = [utils.count_output_size(probe_shape, conv), 512, out_dim]
    fc = _build_noisy_fc(fc_dims, sigma0) if noisy else _build_fc(fc_dims)

    net = BasicNetwork(conv, fc)
    utils.init_net(net, net_file)
    return net
Esempio n. 3
0
def build_basic_network(in_channels, in_size, out_dim, noisy, sigma0,
                        net_file):
    """Build a BasicNetwork (conv feature extractor followed by an FC head)."""
    conv = _build_default_conv(in_channels)

    # The flattened conv output size determines the first FC dimension.
    fc_in = utils.count_output_size((1, in_channels, in_size, in_size), conv)
    layer_dims = [fc_in, 512, out_dim]
    if noisy:
        head = _build_noisy_fc(layer_dims, sigma0)
    else:
        head = _build_fc(layer_dims)

    net = BasicNetwork(conv, head)
    utils.init_net(net, net_file)
    return net
Esempio n. 4
0
def build_dueling_network(in_channels, in_size, out_dim, noisy, sigma0, net_file):
    """Assemble a dueling network: shared conv, advantage head, value head."""
    conv = _build_default_conv(in_channels)

    # Size the FC inputs from the flattened conv output for a unit batch.
    fc_in = utils.count_output_size((1, in_channels, in_size, in_size), conv)
    fc_hid = 512
    make_fc = (lambda dims: _build_noisy_fc(dims, sigma0)) if noisy else _build_fc
    adv = make_fc([fc_in, fc_hid, out_dim])  # per-action advantage stream
    val = make_fc([fc_in, fc_hid, 1])        # scalar state-value stream

    net = DuelingNetwork(conv, adv, val)
    utils.init_net(net, net_file)
    return net
Esempio n. 5
0
def build_dueling_network(in_channels, in_size, out_dim, noisy, sigma0,
                          net_file):
    """Build a DuelingNetwork with separate advantage and value FC streams."""
    conv = _build_default_conv(in_channels)

    in_shape = (1, in_channels, in_size, in_size)
    flat = utils.count_output_size(in_shape, conv)
    hidden = 512

    if noisy:
        adv_stream = _build_noisy_fc([flat, hidden, out_dim], sigma0)
        val_stream = _build_noisy_fc([flat, hidden, 1], sigma0)
    else:
        adv_stream = _build_fc([flat, hidden, out_dim])
        val_stream = _build_fc([flat, hidden, 1])

    net = DuelingNetwork(conv, adv_stream, val_stream)
    utils.init_net(net, net_file)
    return net
Esempio n. 6
0
    def __init__(self, n_classes, ndf):
        """Convolutional discriminator over an RGB image plus a label channel."""
        super(DiscConv, self).__init__()
        cin = 4  # RGB + Embedding
        self.label_embedding = nn.Embedding(n_classes, 1)

        def conv_block(cin, cout, ks, st):
            # Conv -> BatchNorm -> LeakyReLU, no padding.
            return [
                nn.Conv2d(cin, cout, ks, stride=st, padding=0, bias=False),
                nn.BatchNorm2d(cout),
                nn.LeakyReLU(0.2, True),
            ]

        # Four conv blocks, then average pooling and a 1x1 conv to one logit.
        self.model = nn.Sequential(
            *conv_block(cin, ndf, 3, 1),
            *conv_block(ndf, ndf * 2, 3, 1),
            *conv_block(ndf * 2, ndf * 4, 4, 2),
            *conv_block(ndf * 4, ndf * 4, 4, 2),
            nn.AvgPool2d(3),
            nn.Conv2d(ndf * 4, 1, kernel_size=1, stride=1, padding=0, bias=False),
        )
        init_net(self)
    def setup(self):
        """Build networks, optimizer, scheduler and losses for the current mode.

        In 'train' mode the RNG seed is fixed *before* the network is built so
        weight initialization is reproducible; in any other mode the network is
        built and then trained weights are loaded from `load_epoch`.
        """
        if self.args.mode == 'train':
            self.set_random_seed(self.args.seed)
            # Declare the network structure as a {'network name': network object}
            # dict (add one entry per network).
            self.networks = {'LstmPuncNet': LstmPunctuator(self.args.vocab_size, self.args.embedding_dim, self.args.hidden_size,
                                                           self.args.num_layers, bidirectional=True, num_class=self.args.num_classes)}
            # Initialize the network weights (and move to the configured GPUs).
            self.networks = init_net(self.networks, self.args.init_type, self.args.gpu_ids)

            # Optimizer
            self.optimizers = {'optimizer': torch.optim.SGD(self.networks['LstmPuncNet'].parameters(), lr=self.args.lr, weight_decay=1e-4)}

            # Learning-rate decay policy, one scheduler per optimizer.
            self.schedulers = [get_scheduler(optimizer, self.args) for optimizer in list(self.optimizers.values())]

            # Loss functions; spare ones can be listed here so they are easy to swap in.
            self.objectives = {'CrossEntropyLoss': torch.nn.CrossEntropyLoss(ignore_index=self.args.ignore_index).to(self.device),
                               'NLLLoss': torch.nn.NLLLoss().to(self.device)}
        else:
            # Inference: build the same network and losses, then load trained weights.
            self.networks = {'LstmPuncNet': LstmPunctuator(self.args.vocab_size, self.args.embedding_dim, self.args.hidden_size,
                                                           self.args.num_layers, bidirectional=True, num_class=self.args.num_classes)}
            self.objectives = {'CrossEntropyLoss': torch.nn.CrossEntropyLoss(ignore_index=self.args.ignore_index).to(self.device),
                               'NLLLoss': torch.nn.NLLLoss().to(self.device)}
            self.load_networks(self.networks, self.args.load_epoch)
Esempio n. 8
0
                                            network_downsampling=network_downsampling,
                                            inlier_percentage=inlier_percentage,
                                            use_store_data=True,
                                            store_data_root=training_data_root,
                                            phase="validation", is_hsv=is_hsv,
                                            num_pre_workers=num_workers, visible_interval=30, rgb_mode="rgb")

    train_loader = torch.utils.data.DataLoader(dataset=train_dataset, batch_size=batch_size, shuffle=True,
                                               num_workers=num_workers)
    # BUG FIX: the validation loader previously passed num_workers=batch_size,
    # silently coupling the loader's worker-process count to the batch size.
    validation_loader = torch.utils.data.DataLoader(dataset=validation_dataset, batch_size=batch_size, shuffle=False,
                                                    num_workers=num_workers)

    # Single-channel (depth) student model.
    depth_estimation_model_student = models.FCDenseNet57(n_classes=1)
    # Initialize the depth estimation network with Kaiming He initialization
    depth_estimation_model_student = utils.init_net(depth_estimation_model_student, type="kaiming", mode="fan_in",
                                                    activation_mode="relu",
                                                    distribution="normal")
    # Multi-GPU running
    depth_estimation_model_student = torch.nn.DataParallel(depth_estimation_model_student)
    # Summary network architecture
    if display_architecture:
        torchsummary.summary(depth_estimation_model_student, input_size=(3, height, width))
    # Optimizer
    optimizer = torch.optim.SGD(depth_estimation_model_student.parameters(), lr=max_lr, momentum=0.9)
    # Cyclic LR oscillates between min_lr and max_lr with period num_iter.
    lr_scheduler = scheduler.CyclicLR(optimizer, base_lr=min_lr, max_lr=max_lr, step_size=num_iter)

    # Custom layers
    depth_scaling_layer = models.DepthScalingLayer(epsilon=depth_scaling_epsilon)
    depth_warping_layer = models.DepthWarpingLayer(epsilon=depth_warping_epsilon)
    flow_from_depth_layer = models.FlowfromDepthLayer()
    # Loss functions
Esempio n. 9
0
# Build the generator: UNet11 backbone with a VGG-pretrained encoder.
netG = nets.UNet11(num_classes=1, pretrained='vgg')

# # ngf = 32
# use_dropout = False
# norm_layer = utils.get_norm_layer(norm_type='batch')
# # netG = models.ResnetGenerator(input_nc, output_nc, ngf, norm_layer=norm_layer, use_dropout=use_dropout, n_blocks=9)
# netG = models.LeakyResnetGenerator(input_nc, output_nc, ngf=6, norm_layer=norm_layer, use_dropout=use_dropout, n_blocks=9)
# # netG = models.LeakyResnetGenerator(input_nc, output_nc, ngf=64, norm_layer=norm_layer, use_dropout=use_dropout, n_blocks=9)
# ## Unet Generator
# # netG = models.UnetGenerator(input_nc, output_nc, 7, ngf=3, norm_layer=norm_layer, use_dropout=use_dropout)
# utils.init_net(netG, init_type='normal', init_gain=0.02)
summary(netG, input_size=(3, img_size, img_size))

## Building Discriminator
netD = models.Discriminator(input_nc=1, img_size=img_size)
utils.init_net(netD, init_type='normal', init_gain=0.02)
summary(netD, input_size=(1, img_size, img_size))

# Shared learning rate for both Adam optimizers.
lr = 0.00002
G_optimizer = Adam(netG.parameters(), lr=lr, betas=(0.9, 0.999))
D_optimizer = Adam(netD.parameters(), lr=lr, betas=(0.9, 0.999))

# Resume only when BOTH generator and discriminator snapshots exist.
G_model_path = root / 'G_model.pt'
D_model_path = root / 'D_model.pt'
if G_model_path.exists() and D_model_path.exists():
    state = torch.load(str(G_model_path))
    netG.load_state_dict(state['model'])

    # NOTE(review): only epoch/step counters are read from the discriminator
    # checkpoint here — netD's weights are not restored. Confirm intended.
    state = torch.load(str(D_model_path))
    epoch = state['epoch']
    step = state['step']
Esempio n. 10
0
    def __init__(self, args=args):
        """Set up seeds, model, optimizer, scheduler, data loaders and loss.

        NOTE(review): ``args=args`` binds the module-level ``args`` object as
        the default at class-definition time — presumably intentional; confirm.
        """
        super().__init__()
        self.args = args
        # random_seed setting
        # Seed every RNG *before* building the model so weight init is reproducible.
        random_seed = args.randomseed
        np.random.seed(random_seed)
        torch.manual_seed(random_seed)
        if torch.cuda.device_count() > 1:
            torch.cuda.manual_seed_all(random_seed)
        else:
            torch.cuda.manual_seed(random_seed)
        torch.backends.cudnn.deterministic = True
        torch.backends.cudnn.benchmark = False

        self.device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
        self.slomo = model.Slomo(self.args.data_h, self.args.data_w, self.device)
        self.slomo.to(self.device)
        # Optional custom weight initialization (empty init_type = keep defaults).
        if self.args.init_type != "":
            init_net(self.slomo, self.args.init_type)
            print(self.args.init_type + " initializing slomo done!")
        if self.args.train_continue:
            # Resume: reattach to an existing comet experiment when an id is given,
            # otherwise open a fresh one (unless comet is disabled).
            if not self.args.nocomet and self.args.cometid != "":
                self.comet_exp = ExistingExperiment(
                    previous_experiment=self.args.cometid
                )
            elif not self.args.nocomet and self.args.cometid == "":
                self.comet_exp = Experiment(
                    workspace=self.args.workspace, project_name=self.args.projectname
                )
            else:
                self.comet_exp = None
            # Restore model weights, learning rate and optimizer state.
            self.ckpt_dict = torch.load(self.args.checkpoint)
            self.slomo.load_state_dict(self.ckpt_dict["model_state_dict"])
            self.args.init_learning_rate = self.ckpt_dict["learningRate"]
            self.optimizer = optim.Adam(
                self.slomo.parameters(), lr=self.args.init_learning_rate
            )
            self.optimizer.load_state_dict(self.ckpt_dict["opt_state_dict"])
            print("Pretrained model loaded!")
        else:
            # start logging info in comet-ml
            if not self.args.nocomet:
                self.comet_exp = Experiment(
                    workspace=self.args.workspace, project_name=self.args.projectname
                )
                # self.comet_exp.log_parameters(flatten_opts(self.args))
            else:
                self.comet_exp = None
            # Fresh checkpoint skeleton; epoch == -1 means nothing trained yet.
            self.ckpt_dict = {
                "trainLoss": {},
                "valLoss": {},
                "valPSNR": {},
                "valSSIM": {},
                "learningRate": {},
                "epoch": -1,
                "detail": "End to end Super SloMo.",
                "trainBatchSz": self.args.train_batch_size,
                "validationBatchSz": self.args.validation_batch_size,
            }
            self.optimizer = optim.Adam(
                self.slomo.parameters(), lr=self.args.init_learning_rate
            )
        # LR decays by 10x at each milestone epoch.
        self.scheduler = optim.lr_scheduler.MultiStepLR(
            self.optimizer, milestones=self.args.milestones, gamma=0.1
        )
        # Channel wise mean calculated on adobe240-fps training dataset
        mean = [0.5, 0.5, 0.5]
        std = [1, 1, 1]
        self.normalize = transforms.Normalize(mean=mean, std=std)
        self.transform = transforms.Compose([transforms.ToTensor(), self.normalize])

        # Train loader shuffles; validation loader keeps order.
        trainset = dataloader.SuperSloMo(
            root=self.args.dataset_root + "/train", transform=self.transform, train=True
        )
        self.trainloader = torch.utils.data.DataLoader(
            trainset,
            batch_size=self.args.train_batch_size,
            num_workers=self.args.num_workers,
            shuffle=True,
        )

        validationset = dataloader.SuperSloMo(
            root=self.args.dataset_root + "/validation",
            transform=self.transform,
            # randomCropSize=(128, 128),
            train=False,
        )
        self.validationloader = torch.utils.data.DataLoader(
            validationset,
            batch_size=self.args.validation_batch_size,
            num_workers=self.args.num_workers,
            shuffle=False,
        )
        ### loss
        self.supervisedloss = supervisedLoss()
        # Best validation metrics so far: loss is minimized, PSNR/SSIM maximized.
        self.best = {
            "valLoss": 99999999,
            "valPSNR": -1,
            "valSSIM": -1,
        }
        # Number of checkpoints already produced for the resumed epoch count.
        self.checkpoint_counter = int(
            (self.ckpt_dict["epoch"] + 1) / self.args.checkpoint_epoch
        )
Esempio n. 11
0
##[TensorboardX](https://github.com/lanpa/tensorboardX)
### For visualizing loss and interpolated frames


# writer = SummaryWriter("log")


###Initialize flow computation and arbitrary-time flow interpolation CNNs.


device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
# Flow-computation CNN; 6 in / 4 out channels (presumably two stacked RGB
# frames in, two 2-channel flow fields out — confirm against model.UNet usage).
flowComp = model.UNet(6, 4)
flowComp.to(device)
if args.init_type != "":
    init_net(flowComp, args.init_type)
    print(args.init_type + " initializing flowComp done")
# Arbitrary-time flow-interpolation CNN (20 in / 5 out channels).
ArbTimeFlowIntrp = model.UNet(20, 5)
ArbTimeFlowIntrp.to(device)
if args.init_type != "":
    init_net(ArbTimeFlowIntrp, args.init_type)
    print(args.init_type + " initializing ArbTimeFlowIntrp done")


### Initialization


if args.train_continue:
    if not args.nocomet and args.cometid != "":
        comet_exp = ExistingExperiment(previous_experiment=args.cometid)
    elif not args.nocomet and args.cometid == "":
Esempio n. 12
0
        image_file_names=val_file_names[::20],
        to_augment=True,
        transform=valid_transform)

    train_loader = DataLoader(dataset=train_dataset,
                              batch_size=batch_size,
                              shuffle=True,
                              num_workers=num_workers)
    val_loader = DataLoader(dataset=val_dataset,
                            batch_size=batch_size,
                            shuffle=False,
                            num_workers=num_workers)

    # Building Generator
    netG = UNet_Colorization(num_classes=11, filters_base=16)
    utils.init_net(netG)
    summary(netG, input_size=(3, img_height, img_width))

    # Building Discriminator
    netD = Discriminator(input_nc=3,
                         img_height=img_height,
                         img_width=img_width,
                         filter_base=8,
                         num_block=5)
    utils.init_net(netD)
    summary(netD, input_size=(3, img_height, img_width))

    # Optimizer
    # Both optimizers share lr and GAN-typical betas (0.5, 0.999).
    # G_optimizer = Adam(filter(lambda p: p.requires_grad, netG.parameters()), lr=lr, betas=(0.5, 0.999))
    G_optimizer = Adam(netG.parameters(), lr=lr, betas=(0.5, 0.999))
    D_optimizer = Adam(netD.parameters(), lr=lr, betas=(0.5, 0.999))
                                            img_height=1024,
                                            factor=0.01)
    train_loader = DataLoader(dataset=train_dataset,
                              batch_size=batch_size,
                              shuffle=True,
                              num_workers=num_workers)

    # Training hyper-parameters.
    lr = 1.0e-4
    n_epochs = 200
    add_log = False
    add_output = True
    # Six-class softmax UNet over 3-channel input.
    model = UNet_softmax(num_classes=6,
                         filters_base=6,
                         input_channels=3,
                         add_output=add_output)
    utils.init_net(model)
    summary(model, input_size=(3, 480, 640))

    optimizer = Adam(model.parameters(), lr=lr)

    # Create output directories; mkdir raises OSError when they already exist.
    # NOTE(review): the handler prints an empty line rather than a message.
    try:
        model_root = root / "models"
        model_root.mkdir(mode=0o777, parents=False)
    except OSError:
        print("")

    try:
        results_root = root / "results"
        results_root.mkdir(mode=0o777, parents=False)
    except OSError:
        print("")
Esempio n. 14
0
# Load data tensors (inputs x, targets y, extra features e) onto the device.
x, y, e = utils.get_data(f"data/{dataset}/{dataset}.npz", num_his, num_pred)
x, y, e = x.to(device), y.to(device), e.to(device)
n = x.shape[2]  # node count, also used to size the adjacency matrix below
adj = utils.get_adj("data/{}/distance.csv".format(dataset), n).to(device)
# adj = None

if not os.path.exists(f'experiment/{net_name}'):
    os.mkdir(f'experiment/{net_name}')
# Default model is AG_JNet; STGCN replaces it when selected by name.
net = AG_JNet(in_feature, dim_exp, layers_sm, layers_tm, n, num_his, num_pred,
              device).to(device)
if net_name == "STGCN":
    net = STGCN(n, in_feature, num_his, num_pred).to(device)

num_params = sum(param.numel() for param in net.parameters())
print('模型参数量:', num_params)  # prints the model's parameter count
utils.init_net(net)

criterion = nn.MSELoss().to(device)
opt = Adam(net.parameters(), lr=lr)

# Chronological train/val/test split by sample index, then normalization.
num_sample = x.shape[0]
split_index = int(train_split * num_sample)
split_index1 = int(val_split * num_sample)
train_x, train_y, val_x, val_y, test_x, test_y, train_e, val_e, test_e, mean, std \
    = utils.normalization(x, y, e, split_index, split_index1)

# The training set must be shuffled, otherwise the loss does not go down.
train_loader = DataLoader(dataset=TensorDataset(train_x, train_e, train_y),
                          batch_size=batch_size,
                          shuffle=True)
val_loader = DataLoader(dataset=TensorDataset(val_x, val_e, val_y),
Esempio n. 15
0
        precompute_root.mkdir(mode=0o777, parents=True)
    except OSError:
        pass

    # Fully-convolutional DenseNet used as the feature-descriptor backbone.
    feature_descriptor_model = models.FCDenseNet(
        in_channels=3,
        down_blocks=(3, 3, 3, 3, 3),
        up_blocks=(3, 3, 3, 3, 3),
        bottleneck_layers=4,
        growth_rate=filter_growth_rate,
        out_chans_first_conv=16,
        feature_length=feature_length)
    # Initialize the network with Kaiming He initialization
    utils.init_net(feature_descriptor_model,
                   type="kaiming",
                   mode="fan_in",
                   activation_mode="relu",
                   distribution="normal")
    # Multi-GPU running
    feature_descriptor_model = torch.nn.DataParallel(feature_descriptor_model)

    # Custom layer
    response_map_generator = models.FeatureResponseGeneratorNoSoftThresholding(
    )
    # Evaluation metric
    # threshold=3 — presumably a pixel-distance tolerance; confirm in losses.
    matching_accuracy_metric = losses.MatchingAccuracyMetric(threshold=3)

    if trained_model_path.exists():
        print("Loading {:s} ...".format(str(trained_model_path)))
        pre_trained_state = torch.load(str(trained_model_path))
        step = pre_trained_state['step']
        elif opt.image_width == 128:
            import models.vgg_128 as model
    else:
        raise ValueError('Unknown model: %s' % opt.model)

    # define
    # LSTM modules: a deterministic frame predictor plus gaussian posterior and
    # prior LSTMs; batch size is divided evenly across the configured GPUs.
    frame_predictor = lstm_models.lstm((opt.factor+1)*opt.z_dim, opt.g_dim, opt.rnn_size, opt.predictor_rnn_layers, int(opt.batch_size/len(opt.gpu_ids)))
    posterior_pose = lstm_models.gaussian_lstm(opt.g_dim+opt.factor*opt.z_dim, opt.z_dim, opt.rnn_size, opt.posterior_rnn_layers, int(opt.batch_size/len(opt.gpu_ids)))
    prior = lstm_models.gaussian_lstm(opt.g_dim+opt.factor*opt.z_dim, opt.z_dim, opt.rnn_size, opt.prior_rnn_layers, int(opt.batch_size/len(opt.gpu_ids)))

    # Content/pose encoders and the frame decoder.
    cont_encoder = model.cont_encoder(opt.z_dim*opt.factor, opt.channels*opt.n_past)  #g_dim = 64 or 128
    pose_encoder = model.pose_encoder(opt.g_dim, opt.channels)
    decoder = model.decoder(opt.g_dim, opt.channels)

    # init
    # Normal(0, 0.02) weight initialization for every sub-network.
    frame_predictor = utils.init_net(frame_predictor, init_type='normal', init_gain=0.02, gpu_ids=opt.gpu_ids)
    posterior_pose = utils.init_net(posterior_pose, init_type='normal', init_gain=0.02, gpu_ids=opt.gpu_ids)
    prior = utils.init_net(prior, init_type='normal', init_gain=0.02, gpu_ids=opt.gpu_ids)

    cont_encoder = utils.init_net(cont_encoder, init_type='normal', init_gain=0.02, gpu_ids=opt.gpu_ids)
    pose_encoder = utils.init_net(pose_encoder, init_type='normal', init_gain=0.02, gpu_ids=opt.gpu_ids)
    decoder = utils.init_net(decoder, init_type='normal', init_gain=0.02, gpu_ids=opt.gpu_ids)

    # load
    # Restore the 'last' snapshot of every sub-network from opt.model_path.
    utils.load_network(frame_predictor, 'frame_predictor', 'last', opt.model_path,device)
    utils.load_network(posterior_pose, 'posterior_pose', 'last', opt.model_path,device)
    utils.load_network(prior, 'prior', 'last', opt.model_path,device)

    utils.load_network(cont_encoder, 'cont_encoder', 'last', opt.model_path,device)
    utils.load_network(pose_encoder, 'pose_encoder', 'last', opt.model_path,device)
    utils.load_network(decoder, 'decoder', 'last', opt.model_path,device)
    # Fix randomness for reproducibility
    # NOTE(review): deterministic=False with benchmark=True trades bitwise
    # cudnn determinism for speed; runs are only seed-reproducible.
    torch.backends.cudnn.deterministic = False
    torch.backends.cudnn.benchmark = True
    torch.manual_seed(10085)
    np.random.seed(10085)
    random.seed(10085)

    if not Path(args.precompute_root).exists():
        Path(args.precompute_root).mkdir(parents=True)

    model = models.FCDenseNetStd(
        in_channels=3, down_blocks=(4, 4, 4, 4, 4),
        up_blocks=(4, 4, 4, 4, 4), bottleneck_layers=4,
        growth_rate=12, out_chans_first_conv=48)
    # Initialize the depth estimation network with Kaiming He initialization
    utils.init_net(model, type="kaiming", mode="fan_in", activation_mode="relu",
                   distribution="normal")
    # Multi-GPU running
    model = torch.nn.DataParallel(model)

    # Load previous depth estimation model
    if Path(args.trained_model_path).exists():
        print("Loading {:s} ...".format(str(args.trained_model_path)))
        state = torch.load(str(args.trained_model_path), encoding='latin1')
        model.load_state_dict(state["model"])
        step = state['step']
        epoch = state['epoch']
        print('Restored model, epoch {}, step {}'.format(epoch, step))
    else:
        # A previously trained model is required here; abort when none exists.
        print("No previous model detected")
        raise OSError
Esempio n. 18
0
    label_path = "../datasets/lumi/B/test"

    input_file_names = utils.read_lumi_filenames(input_path)
    label_file_names = utils.read_lumi_filenames(label_path)

    # Paired input/label dataset for evaluation — order preserved (no shuffle).
    # NOTE(review): rebinding `dataset` shadows the imported module of the
    # same name; any later use of the module would break.
    dataset = dataset.LumiDataset(input_filenames=input_file_names,
                                  label_filenames=label_file_names,
                                  transform=test_transform)
    loader = DataLoader(dataset=dataset,
                        batch_size=batch_size,
                        shuffle=False,
                        num_workers=num_workers)

    # Building Generator
    netG = Generator(num_classes=3, filters_base=16)
    netG = utils.init_net(netG)
    summary(netG, input_size=(3, img_height, img_width))

    # Create output directories; OSError means they already exist.
    try:
        model_root = root / "models"
        model_root.mkdir(mode=0o777, parents=False)
    except OSError:
        print("path exists")

    try:
        results_root = root / "results"
        results_root.mkdir(mode=0o777, parents=False)
    except OSError:
        print("path exists")
    print('Restoring mode')
Esempio n. 19
0
    def __init__(self, args=args):
        """Set up seeds, model, optimizer, scheduler, data loaders and losses.

        Two modes share this constructor: the contrastive pretrain stage
        (``args.pretrainstage`` true — wraps the model in a ContrastiveLearner
        and optimizes the learner) and the normal end-to-end stage.

        NOTE(review): ``args=args`` binds the module-level ``args`` object as
        the default at class-definition time — presumably intentional; confirm.
        """
        super().__init__()
        self.args = args
        # random_seed setting
        # Seed every RNG *before* building the model so weight init is reproducible.
        random_seed = args.randomseed
        np.random.seed(random_seed)
        torch.manual_seed(random_seed)
        if torch.cuda.device_count() > 1:
            torch.cuda.manual_seed_all(random_seed)
        else:
            torch.cuda.manual_seed(random_seed)
        torch.backends.cudnn.deterministic = True
        torch.backends.cudnn.benchmark = False
        self.pretrain_stage = self.args.pretrainstage
        self.device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
        self.slomofc = model.Slomofc(
            self.args.data_h, self.args.data_w, self.device, self.pretrain_stage
        )
        self.slomofc.to(self.device)
        if self.pretrain_stage:
            # CURL-style contrastive wrapper around the backbone.
            self.learner = ContrastiveLearner(
                self.slomofc,
                image_size=128,
                hidden_layer="avgpool",
                use_momentum=True,  # use momentum for key encoder
                momentum_value=0.999,
                project_hidden=False,  # no projection heads
                use_bilinear=True,  # in paper, logits is bilinear product of query / key
                use_nt_xent_loss=False,  # use regular contrastive loss
                augment_both=False,  # in curl, only the key is augmented
            )
        # Optional custom weight initialization (empty init_type = keep defaults).
        if self.args.init_type != "":
            init_net(self.slomofc, self.args.init_type)
            print(self.args.init_type + " initializing slomo done!")
        if self.args.train_continue:
            # Resume: reattach to an existing comet experiment when an id is given,
            # otherwise open a fresh one (unless comet is disabled).
            if not self.args.nocomet and self.args.cometid != "":
                self.comet_exp = ExistingExperiment(
                    previous_experiment=self.args.cometid
                )
            elif not self.args.nocomet and self.args.cometid == "":
                self.comet_exp = Experiment(
                    workspace=self.args.workspace, project_name=self.args.projectname
                )
            else:
                self.comet_exp = None
            # Restore model weights, learning rate and optimizer state.
            self.ckpt_dict = torch.load(self.args.checkpoint)
            self.slomofc.load_state_dict(self.ckpt_dict["model_state_dict"])
            self.args.init_learning_rate = self.ckpt_dict["learningRate"]
            if not self.pretrain_stage:
                self.optimizer = optim.Adam(
                    self.slomofc.parameters(), lr=self.args.init_learning_rate
                )
            else:
                # Pretrain stage optimizes the contrastive learner at a fixed lr.
                self.optimizer = optim.Adam(self.learner.parameters(), lr=3e-4)
            self.optimizer.load_state_dict(self.ckpt_dict["opt_state_dict"])
            print("Pretrained model loaded!")
        else:
            # start logging info in comet-ml
            if not self.args.nocomet:
                self.comet_exp = Experiment(
                    workspace=self.args.workspace, project_name=self.args.projectname
                )
                # self.comet_exp.log_parameters(flatten_opts(self.args))
            else:
                self.comet_exp = None
            # Fresh checkpoint skeleton; epoch == -1 means nothing trained yet.
            if not self.pretrain_stage:
                self.ckpt_dict = {
                    "trainLoss": {},
                    "valLoss": {},
                    "valPSNR": {},
                    "valSSIM": {},
                    "learningRate": {},
                    "epoch": -1,
                    "detail": "End to end Super SloMo.",
                    "trainBatchSz": self.args.train_batch_size,
                    "validationBatchSz": self.args.validation_batch_size,
                }
            else:
                self.ckpt_dict = {
                    "conLoss": {},
                    "learningRate": {},
                    "epoch": -1,
                    "detail": "Pretrain_stage of Super SloMo.",
                    "trainBatchSz": self.args.train_batch_size,
                }
            if not self.pretrain_stage:
                self.optimizer = optim.Adam(
                    self.slomofc.parameters(), lr=self.args.init_learning_rate
                )
            else:
                self.optimizer = optim.Adam(self.learner.parameters(), lr=3e-4)
        # LR decays by 10x at each milestone epoch.
        self.scheduler = optim.lr_scheduler.MultiStepLR(
            self.optimizer, milestones=self.args.milestones, gamma=0.1
        )
        # Channel wise mean calculated on adobe240-fps training dataset
        if not self.pretrain_stage:
            mean = [0.5, 0.5, 0.5]
            std = [1, 1, 1]
            self.normalize = transforms.Normalize(mean=mean, std=std)
            self.transform = transforms.Compose([transforms.ToTensor(), self.normalize])
        else:
            # Pretraining skips normalization; augmentation happens in the learner.
            self.transform = transforms.Compose([transforms.ToTensor()])

        trainset = dataloader.SuperSloMo(
            root=self.args.dataset_root + "/train", transform=self.transform, train=True
        )
        self.trainloader = torch.utils.data.DataLoader(
            trainset,
            batch_size=self.args.train_batch_size,
            num_workers=self.args.num_workers,
            shuffle=True,
        )
        # Validation data is only needed outside the pretrain stage.
        if not self.pretrain_stage:
            validationset = dataloader.SuperSloMo(
                root=self.args.dataset_root + "/validation",
                transform=self.transform,
                # randomCropSize=(128, 128),
                train=False,
            )
            self.validationloader = torch.utils.data.DataLoader(
                validationset,
                batch_size=self.args.validation_batch_size,
                num_workers=self.args.num_workers,
                shuffle=False,
            )
        ### loss
        # Best-metric trackers: supervised stage tracks loss/PSNR/SSIM,
        # pretrain stage tracks only the contrastive loss.
        if not self.pretrain_stage:
            self.supervisedloss = supervisedLoss()
            self.best = {
                "valLoss": 99999999,
                "valPSNR": -1,
                "valSSIM": -1,
            }
        else:
            self.best = {
                "conLoss": 99999999,
            }
        # Number of checkpoints already produced for the resumed epoch count.
        self.checkpoint_counter = int(
            (self.ckpt_dict["epoch"] + 1) / self.args.checkpoint_epoch
        )