コード例 #1
0
def main():
    """Train and validate a ResNet classifier configured by a YAML file.

    Reads the config path and the resume flag from ``get_arguments()``,
    builds training/validation loaders from ``FlowersDataset``, fine-tunes a
    pretrained ResNet18/34 and, every epoch, writes a checkpoint, the running
    ``log.csv`` and (when validation top-1 accuracy improves) the best
    weights into ``CONFIG.result_path``.

    The process exits with status 1 when CUDA is unavailable or when the
    config names an unsupported model or optimizer.
    """
    args = get_arguments()

    # configuration (context manager so the file handle is closed promptly)
    with open(args.config) as f:
        CONFIG = Dict(yaml.safe_load(f))

    # cpu or cuda
    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    if device == 'cuda':
        torch.backends.cudnn.benchmark = True
    else:
        print('You have to use GPUs because training CNN is computationally expensive.')
        sys.exit(1)

    # Dataloader
    train_data = FlowersDataset(
        CONFIG,
        transform=Compose([
            RandomResizedCrop(size=(CONFIG.height, CONFIG.width)),
            RandomHorizontalFlip(),
            ColorJitter(brightness=0.4, contrast=0.4, saturation=0.4, hue=0.1),
            ToTensor(),
            Normalize(mean=get_mean(), std=get_std())
        ]),
        mode='training'
    )

    val_data = FlowersDataset(
        CONFIG,
        transform=Compose([
            ToTensor(),
            Normalize(mean=get_mean(), std=get_std())
        ]),
        mode='validation'
    )

    train_loader = DataLoader(
        train_data,
        batch_size=CONFIG.batch_size,
        shuffle=True,
        num_workers=CONFIG.num_workers,
        pin_memory=True,
        drop_last=True
    )

    val_loader = DataLoader(
        val_data,
        batch_size=1,
        shuffle=False,
        num_workers=CONFIG.num_workers,
        pin_memory=True
    )

    # load model
    print('\n------------------------Loading Model------------------------\n')

    # the number of classes
    n_classes = len(get_cls2id_map())

    # TODO: define a function to get models
    if CONFIG.model == 'resnet18':
        print('ResNet18 will be used as a model.')
        model = torchvision.models.resnet18(pretrained=True)
    elif CONFIG.model == 'resnet34':
        print('ResNet34 will be used as a model.')
        model = torchvision.models.resnet34(pretrained=True)
    else:
        print('There is no model appropriate to your choice. '
              'You have to choose resnet18 or resnet34 as a model in config.yaml')
        sys.exit(1)

    # replace the pretrained classification head with one sized for our classes
    model.fc = nn.Linear(
        in_features=model.fc.in_features,
        out_features=n_classes,
        bias=True
    )

    # send the model to cuda/cpu
    model.to(device)

    if CONFIG.optimizer == 'Adam':
        print(CONFIG.optimizer + ' will be used as an optimizer.')
        optimizer = optim.Adam(model.parameters(), lr=CONFIG.learning_rate)
    elif CONFIG.optimizer == 'SGD':
        print(CONFIG.optimizer + ' will be used as an optimizer.')
        optimizer = optim.SGD(
            model.parameters(),
            lr=CONFIG.learning_rate,
            momentum=CONFIG.momentum,
            dampening=CONFIG.dampening,
            weight_decay=CONFIG.weight_decay,
            nesterov=CONFIG.nesterov
        )
    else:
        print(
            'There is no optimizer which suits to your option. '
            'You have to choose SGD or Adam as an optimizer in config.yaml')
        # stop here: without an optimizer the code below would fail with NameError
        sys.exit(1)

    # learning rate scheduler
    if CONFIG.scheduler == 'onplateau':
        scheduler = optim.lr_scheduler.ReduceLROnPlateau(
            optimizer, 'min', patience=CONFIG.lr_patience
        )
    else:
        scheduler = None

    # resume if you want
    begin_epoch = 0
    best_acc1 = 0
    log = pd.DataFrame(
        columns=[
            'epoch', 'lr', 'train_loss', 'val_loss',
            'train_acc@1', 'val_acc@1', 'train_f1s', 'val_f1s'
        ]
    )
    log_path = os.path.join(CONFIG.result_path, 'log.csv')

    if args.resume:
        resume_path = os.path.join(CONFIG.result_path, 'checkpoint.pth')
        if os.path.exists(resume_path):
            print('loading the checkpoint...')
            begin_epoch, model, optimizer, best_acc1, scheduler = resume(
                resume_path, model, optimizer, scheduler)
            print('training will start from {} epoch'.format(begin_epoch))
        else:
            print("there is no checkpoint at the result folder")
        if os.path.exists(log_path):
            print('loading the log file...')
            log = pd.read_csv(log_path)
        else:
            print("there is no log file at the result folder.")
            print('Making a log file...')

    # criterion for loss
    if CONFIG.class_weight:
        criterion = nn.CrossEntropyLoss(
            weight=get_class_weight(n_classes=n_classes).to(device)
        )
    else:
        criterion = nn.CrossEntropyLoss()

    # train and validate model
    print('\n------------------------Start training------------------------\n')

    for epoch in range(begin_epoch, CONFIG.max_epoch):

        # training
        train_loss, train_acc1, train_f1s = train(
            train_loader, model, criterion, optimizer, epoch, device)

        # validation
        val_loss, val_acc1, val_f1s = validate(
            val_loader, model, criterion, device)

        # scheduler (ReduceLROnPlateau in 'min' mode monitors validation loss)
        if scheduler is not None:
            scheduler.step(val_loss)

        # save a model if top1 acc is higher than ever
        if best_acc1 < val_acc1:
            best_acc1 = val_acc1
            torch.save(
                model.state_dict(),
                os.path.join(CONFIG.result_path, 'best_acc1_model.prm')
            )

        # save checkpoint every epoch
        save_checkpoint(
            CONFIG.result_path, epoch, model, optimizer, best_acc1, scheduler)

        # save a model every 10 epoch
        if epoch % 10 == 0 and epoch != 0:
            save_checkpoint(
                CONFIG.result_path, epoch, model, optimizer,
                best_acc1, scheduler, add_epoch2name=True
            )

        # write logs to dataframe and csv file
        tmp = pd.Series([
            epoch,
            optimizer.param_groups[0]['lr'],
            train_loss,
            val_loss,
            train_acc1,
            val_acc1,
            train_f1s,
            val_f1s
        ], index=log.columns
        )

        # DataFrame.append was removed in pandas 2.0; concat a one-row frame instead
        log = pd.concat([log, tmp.to_frame().T], ignore_index=True)
        log.to_csv(log_path, index=False)

        print(
            'epoch: {}\tlr: {}\ttrain loss: {:.4f}\tval loss: {:.4f}\tval_acc1: {:.5f}\tval_f1s: {:.5f}'
            .format(epoch, optimizer.param_groups[0]['lr'], train_loss,
                    val_loss, val_acc1, val_f1s)
        )

    # save models
    torch.save(
        model.state_dict(), os.path.join(CONFIG.result_path, 'final_model.prm'))
コード例 #2
0
ファイル: train.py プロジェクト: jo-kwsm/ssd
def main():
    """Train a detection model and keep logs next to the config file.

    The directory containing ``args.config`` is used as the result folder:
    the per-epoch checkpoint, ``log.csv``, the best weights (lowest
    validation loss) and the final weights are all written there. Nothing is
    done when ``final_model.prm`` already exists in that folder.

    Raises:
        FileNotFoundError: when resuming is requested but ``log.csv`` is
            missing from the result folder.
    """
    args = get_arguments()
    config = get_config(args.config)

    result_path = os.path.dirname(args.config)

    if os.path.exists(os.path.join(result_path, "final_model.prm")):
        print("Already done.")
        return

    device = get_device(allow_only_gpu=True)

    transform = DataTransform(config.size, get_mean())
    voc_classes = [k for k in get_cls2id_map().keys()]

    train_loader = get_dataloader(
        config.train_csv,
        phase="train",
        batch_size=config.batch_size,
        shuffle=True,
        num_workers=config.num_workers,
        pin_memory=True,
        drop_last=True,
        transform=transform,
        transform_anno=Anno_xml2list(voc_classes),
    )

    # NOTE(review): shuffle/drop_last are enabled for validation as well;
    # confirm this is intended, since batch_size is 1.
    val_loader = get_dataloader(
        config.val_csv,
        phase="val",
        batch_size=1,
        shuffle=True,
        num_workers=config.num_workers,
        pin_memory=True,
        drop_last=True,
        transform=transform,
        transform_anno=Anno_xml2list(voc_classes),
    )

    # one extra class on top of the annotated ones
    # (presumably the background class - TODO confirm)
    n_classes = len(voc_classes) + 1
    model = get_model(
        input_size=config.size,
        n_classes=n_classes,
        phase="train",
        pretrained=config.pretrained,
    )
    model.to(device)

    optimizer = optim.SGD(
        model.parameters(),
        lr=config.learning_rate,
        momentum=0.9,
        weight_decay=5e-4
    )

    begin_epoch = 0
    best_loss = float("inf")
    # TODO: consider which evaluation metrics to use
    log = pd.DataFrame(
        columns=[
            "epoch",
            "lr",
            "train_time[sec]",
            "train_loss",
            "val_time[sec]",
            "val_loss",
        ]
    )
    log_path = os.path.join(result_path, "log.csv")
    checkpoint_path = os.path.join(result_path, "checkpoint.pth")

    if args.resume:
        begin_epoch, model, optimizer, best_loss = resume(
            checkpoint_path, model, optimizer)

        # explicit raise instead of assert: asserts are stripped under -O and
        # the old message wrongly blamed the checkpoint rather than the log
        if not os.path.exists(log_path):
            raise FileNotFoundError(
                "there is no log file at the result folder")
        log = pd.read_csv(log_path)

    criterion = get_criterion(device=device)

    print("---------- Start training ----------")

    for epoch in range(begin_epoch, config.max_epoch):
        start = time.time()
        train_loss = train(
            train_loader,
            model,
            criterion,
            optimizer,
            epoch,
            device,
            interval_of_progress=10,
        )
        train_time = int(time.time() - start)

        start = time.time()
        val_loss = evaluate(
            val_loader,
            model,
            criterion,
            device,
        )
        val_time = int(time.time() - start)

        # keep the weights with the lowest validation loss so far
        if best_loss > val_loss:
            best_loss = val_loss
            torch.save(
                model.state_dict(),
                os.path.join(result_path, "best_model.prm"),
            )

        save_checkpoint(result_path, epoch, model, optimizer, best_loss)

        tmp = pd.Series(
            [
                epoch,
                optimizer.param_groups[0]["lr"],
                train_time,
                train_loss,
                val_time,
                val_loss,
            ],
            index=log.columns,
        )

        # DataFrame.append was removed in pandas 2.0; concat a one-row frame instead
        log = pd.concat([log, tmp.to_frame().T], ignore_index=True)
        log.to_csv(log_path, index=False)
        make_graphs(log_path)

        print(
            """epoch: {}\tepoch time[sec]: {}\tlr: {}\ttrain loss: {:.4f}\t\
            val loss: {:.4f}
            """.format(
                epoch,
                train_time + val_time,
                optimizer.param_groups[0]["lr"],
                train_loss,
                val_loss,
            )
        )

    torch.save(model.state_dict(), os.path.join(result_path, "final_model.prm"))

    # the checkpoint is only needed for resuming; it may be absent when no
    # epoch was run in this session
    if os.path.exists(checkpoint_path):
        os.remove(checkpoint_path)

    print("Done")
コード例 #3
0
ファイル: train.py プロジェクト: llien30/pytorch_template
def main() -> None:
    """Train and evaluate an image classifier, logging to CSV and optionally W&B.

    The directory containing ``args.config`` is used as the result folder and
    its base name as the experiment name. Every epoch the checkpoint and
    ``log.csv`` are updated there; the weights with the lowest validation
    loss are kept as ``best_model.prm`` and the last weights as
    ``final_model.prm``. Weights & Biases logging is enabled unless
    ``args.no_wandb`` is set.

    Raises:
        FileNotFoundError: when resuming is requested but ``log.csv`` is
            missing from the result folder.
    """
    args = get_arguments()

    # configuration
    config = get_config(args.config)

    # save log files in the directory which contains config file.
    result_path = os.path.dirname(args.config)
    experiment_name = os.path.basename(result_path)

    # cpu or cuda
    device = get_device(allow_only_gpu=True)

    # Dataloader
    train_transform = Compose([
        RandomResizedCrop(size=(config.height, config.width)),
        RandomHorizontalFlip(),
        ColorJitter(brightness=0.4, contrast=0.4, saturation=0.4, hue=0.1),
        ToTensor(),
        Normalize(mean=get_mean(), std=get_std()),
    ])

    val_transform = Compose(
        [ToTensor(), Normalize(mean=get_mean(), std=get_std())])

    train_loader = get_dataloader(
        config.train_csv,
        batch_size=config.batch_size,
        shuffle=True,
        num_workers=config.num_workers,
        pin_memory=True,
        drop_last=True,
        transform=train_transform,
    )

    val_loader = get_dataloader(
        config.val_csv,
        batch_size=1,
        shuffle=False,
        num_workers=config.num_workers,
        pin_memory=True,
        transform=val_transform,
    )

    # the number of classes
    n_classes = len(get_cls2id_map())

    # define a model
    model = get_model(config.model, n_classes, pretrained=config.pretrained)

    # send the model to cuda/cpu
    model.to(device)

    optimizer = optim.Adam(model.parameters(), lr=config.learning_rate)

    # keep training and validation log
    begin_epoch = 0
    best_loss = float("inf")
    log = pd.DataFrame(columns=[
        "epoch",
        "lr",
        "train_time[sec]",
        "train_loss",
        "train_acc@1",
        "train_f1s",
        "val_time[sec]",
        "val_loss",
        "val_acc@1",
        "val_f1s",
    ])
    log_path = os.path.join(result_path, "log.csv")
    checkpoint_path = os.path.join(result_path, "checkpoint.pth")

    # resume if you want
    if args.resume:
        begin_epoch, model, optimizer, best_loss = resume(
            checkpoint_path, model, optimizer)

        # explicit raise instead of assert: asserts are stripped under -O and
        # the old message wrongly blamed the checkpoint rather than the log
        if not os.path.exists(log_path):
            raise FileNotFoundError(
                "there is no log file at the result folder")
        log = pd.read_csv(log_path)

    # criterion for loss
    criterion = get_criterion(config.use_class_weight, config.train_csv,
                              device)

    # Weights and biases
    if not args.no_wandb:
        wandb.init(
            name=experiment_name,
            config=config,
            project="image_classification_template",
            job_type="training",
            # the keyword accepted by wandb.init is `dir`, not `dirs`
            dir="./wandb_result/",
        )
        # Magic
        wandb.watch(model, log="all")

    # train and validate model
    print("---------- Start training ----------")

    for epoch in range(begin_epoch, config.max_epoch):
        # training
        start = time.time()
        train_loss, train_acc1, train_f1s = train(train_loader, model,
                                                  criterion, optimizer, epoch,
                                                  device)
        train_time = int(time.time() - start)

        # validation (the confusion matrix returned last is not used here)
        start = time.time()
        val_loss, val_acc1, val_f1s, _c_matrix = evaluate(
            val_loader, model, criterion, device)
        val_time = int(time.time() - start)

        # save a model if validation loss is lower than ever
        if best_loss > val_loss:
            best_loss = val_loss
            torch.save(
                model.state_dict(),
                os.path.join(result_path, "best_model.prm"),
            )

        # save checkpoint every epoch
        save_checkpoint(result_path, epoch, model, optimizer, best_loss)

        # write logs to dataframe and csv file
        tmp = pd.Series(
            [
                epoch,
                optimizer.param_groups[0]["lr"],
                train_time,
                train_loss,
                train_acc1,
                train_f1s,
                val_time,
                val_loss,
                val_acc1,
                val_f1s,
            ],
            index=log.columns,
        )

        # DataFrame.append was removed in pandas 2.0; concat a one-row frame instead
        log = pd.concat([log, tmp.to_frame().T], ignore_index=True)
        log.to_csv(log_path, index=False)

        # save logs to wandb
        if not args.no_wandb:
            wandb.log(
                {
                    "lr": optimizer.param_groups[0]["lr"],
                    "train_time[sec]": train_time,
                    "train_loss": train_loss,
                    "train_acc@1": train_acc1,
                    "train_f1s": train_f1s,
                    "val_time[sec]": val_time,
                    "val_loss": val_loss,
                    "val_acc@1": val_acc1,
                    "val_f1s": val_f1s,
                },
                step=epoch,
            )

        print("""epoch: {}\tepoch time[sec]: {}\tlr: {}\ttrain loss: {:.4f}\t\
            val loss: {:.4f} val_acc1: {:.5f}\tval_f1s: {:.5f}
            """.format(
            epoch,
            train_time + val_time,
            optimizer.param_groups[0]["lr"],
            train_loss,
            val_loss,
            val_acc1,
            val_f1s,
        ))

    # save models
    torch.save(model.state_dict(), os.path.join(result_path,
                                                "final_model.prm"))

    # delete checkpoint (it may be absent when no epoch was run in this session)
    if os.path.exists(checkpoint_path):
        os.remove(checkpoint_path)

    print("Done")
コード例 #4
0
def main():
    """Train a Multi-Stage TCN for action segmentation and report the best epochs.

    Hyperparameters come from the YAML file given by ``args.config``. When
    ``CONFIG.param_search`` is truthy the model is validated after every
    epoch, the validation metrics are appended to ``log.csv`` and the weights
    with the best loss / F1 / edit score / accuracy are saved separately in
    ``CONFIG.result_path``; otherwise only the training loss is logged.
    A checkpoint is written every epoch and the final weights at the end,
    followed by a printed summary of the best epochs and the configuration.

    The process exits with status 1 when CUDA is unavailable or the
    configured optimizer is not supported.
    """
    # argparser
    args = get_arguments()

    # configuration: read every setting from the config file
    # (context manager so the file handle is closed promptly)
    with open(args.config) as f:
        CONFIG = Dict(yaml.safe_load(f))

    # fix the random seeds so that initialization is the same on every run
    seed = args.seed
    random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    # ask cuDNN to return deterministic convolution algorithms
    torch.backends.cudnn.deterministic = True

    # cpu or cuda
    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    if device == 'cuda':
        # choose which GPU the model runs on; done only after CUDA is known
        # to be available so CPU-only machines reach the message below
        torch.cuda.set_device(CONFIG.device)
        # NOTE(review): benchmark mode lets cuDNN autotune its algorithms,
        # which may work against the determinism requested above - confirm.
        torch.backends.cudnn.benchmark = True
    else:
        print(
            'You have to use GPUs because training CNN is computationally expensive.'
        )
        sys.exit(1)

    # Dataloader
    # Temporal downsampling is applied to only videos in 50Salads
    # (frames are subsampled there to keep the frame rate consistent)
    print("Dataset: {}\tSplit: {}".format(
        CONFIG.dataset, CONFIG.split))
    print("Batch Size: {}\tNum in channels: {}\tNum Workers: {}".format(
        CONFIG.batch_size, CONFIG.in_channel, CONFIG.num_workers))

    downsamp_rate = 2 if CONFIG.dataset == '50salads' else 1

    train_data = ActionSegmentationDataset(
        CONFIG.dataset,
        transform=Compose([
            ToTensor(),
            TempDownSamp(downsamp_rate),
        ]),
        mode='trainval' if not CONFIG.param_search else 'training',
        split=CONFIG.split,
        dataset_dir=CONFIG.dataset_dir,
        csv_dir=CONFIG.csv_dir)

    # data loading
    train_loader = DataLoader(
        train_data,
        batch_size=CONFIG.batch_size,
        shuffle=True,
        num_workers=CONFIG.num_workers,
        drop_last=True if CONFIG.batch_size > 1 else False,
        collate_fn=collate_fn)

    # if you do validation to determine hyperparams
    # (when enabled, validation runs once after every training epoch)
    if CONFIG.param_search:
        val_data = ActionSegmentationDataset(CONFIG.dataset,
                                             transform=Compose([
                                                 ToTensor(),
                                                 TempDownSamp(downsamp_rate),
                                             ]),
                                             mode='validation',
                                             split=CONFIG.split,
                                             dataset_dir=CONFIG.dataset_dir,
                                             csv_dir=CONFIG.csv_dir)

        val_loader = DataLoader(val_data,
                                batch_size=1,
                                shuffle=False,
                                num_workers=CONFIG.num_workers)

    # load model
    print('\n------------------------Loading Model------------------------\n')

    # the number of classes
    n_classes = get_n_classes(CONFIG.dataset,
                              dataset_dir=CONFIG.dataset_dir)

    print('Multi Stage TCN will be used as a model.')
    print(
        'stages: {}\tn_features: {}\tn_layers of dilated TCN: {}\tkernel_size of ED-TCN: {}'
        .format(CONFIG.stages, CONFIG.n_features, CONFIG.dilated_n_layers,
                CONFIG.kernel_size))
    # model initialization
    model = models.MultiStageTCN(
        in_channel=CONFIG.in_channel,
        n_classes=n_classes,
        stages=CONFIG.stages,
        n_features=CONFIG.n_features,
        dilated_n_layers=CONFIG.dilated_n_layers,
        kernel_size=CONFIG.kernel_size)

    # send the model to cuda/cpu
    model.to(CONFIG.device)

    # choose the optimizer
    if CONFIG.optimizer == 'Adam':
        print(CONFIG.optimizer + ' will be used as an optimizer.')
        optimizer = optim.Adam(model.parameters(), lr=CONFIG.learning_rate)
    elif CONFIG.optimizer == 'SGD':
        print(CONFIG.optimizer + ' will be used as an optimizer.')
        optimizer = optim.SGD(model.parameters(),
                              lr=CONFIG.learning_rate,
                              momentum=CONFIG.momentum,
                              dampening=CONFIG.dampening,
                              weight_decay=CONFIG.weight_decay,
                              nesterov=CONFIG.nesterov)
    elif CONFIG.optimizer == 'AdaBound':
        print(CONFIG.optimizer + ' will be used as an optimizer.')
        optimizer = adabound.AdaBound(model.parameters(),
                                      lr=CONFIG.learning_rate,
                                      final_lr=CONFIG.final_lr,
                                      weight_decay=CONFIG.weight_decay)
    else:
        print('There is no optimizer which suits to your option.')
        sys.exit(1)

    # learning rate scheduler
    # NOTE(review): the scheduler is created and checkpointed but never
    # stepped in this function - confirm whether scheduler.step(val_loss)
    # is missing or performed elsewhere.
    if CONFIG.scheduler == 'onplateau':
        scheduler = optim.lr_scheduler.ReduceLROnPlateau(
            optimizer, 'min', patience=CONFIG.lr_patience)
    else:
        scheduler = None

    # resume if you want
    columns = ['epoch', 'lr', 'train_loss']

    # if you do validation to determine hyperparams
    # (must test the same flag that decides whether validation values are
    # appended to each log row, otherwise the row length would not match)
    if CONFIG.param_search:
        columns += ['val_loss', 'acc', 'edit']
        columns += [
            "f1s@{}".format(CONFIG.thresholds[i])
            for i in range(len(CONFIG.thresholds))
        ]

    begin_epoch = 0
    best_loss = 100
    log = pd.DataFrame(columns=columns)
    log_path = os.path.join(CONFIG.result_path, 'log.csv')
    if args.resume:
        if os.path.exists(os.path.join(CONFIG.result_path, 'checkpoint.pth')):
            print('loading the checkpoint...')
            checkpoint = resume(CONFIG.result_path, model, optimizer,
                                scheduler)
            begin_epoch, model, optimizer, best_loss, scheduler = checkpoint
            print('training will start from {} epoch'.format(begin_epoch))
        else:
            print("there is no checkpoint at the result folder")
        if os.path.exists(log_path):
            print('loading the log file...')
            log = pd.read_csv(log_path)
        else:
            print("there is no log file at the result folder.")
            print('Making a log file...')

    # criterion for loss
    if CONFIG.class_weight:
        class_weight = get_class_weight(CONFIG.dataset,
                                        split=CONFIG.split,
                                        csv_dir=CONFIG.csv_dir)
        class_weight = class_weight.to(device)
    else:
        class_weight = None

    criterion = ActionSegmentationLoss(ce=CONFIG.ce,
                                       tmse=CONFIG.tmse,
                                       weight=class_weight,
                                       ignore_index=255,
                                       tmse_weight=CONFIG.tmse_weight)

    # best validation scores over all epochs, plus the log row of the epoch
    # that achieved each of them (None until a score above 0 is seen)
    Best_F1 = 0
    Best_Edit = 0
    All_Time = 0
    Best_Acc = 0
    Best_F1_All = None
    Best_Edit_All = None
    Best_Acc_All = None

    # train and validate model
    print(
        '\n---------------------------Start training---------------------------\n'
    )
    for epoch in range(begin_epoch, CONFIG.max_epoch):
        # training
        start = time.time()
        train_loss = train(train_loader, model, criterion, optimizer, epoch,
                           CONFIG, device)
        train_time = (time.time() - start)

        # stays 0 when validation is skipped so the total time is still defined
        val_time = 0

        # if you do validation to determine hyperparams
        if CONFIG.param_search:
            start = time.time()
            val_loss, acc, edit_score, f1s = validate(val_loader, model,
                                                      criterion, CONFIG,
                                                      device)
            val_time = (time.time() - start)

            # save a model if validation loss is lower than ever
            if best_loss > val_loss:
                best_loss = val_loss
                torch.save(
                    model.state_dict(),
                    os.path.join(CONFIG.result_path, 'best_loss_model.prm'))

        # save checkpoint every epoch
        save_checkpoint(CONFIG.result_path, epoch, model, optimizer, best_loss,
                        scheduler)

        # write logs to dataframe and csv file
        tmp = [epoch, optimizer.param_groups[0]['lr'], train_loss]

        # if you do validation to determine hyperparams
        if CONFIG.param_search:
            tmp += [val_loss, acc, edit_score]
            tmp += [f1s[i] for i in range(len(CONFIG.thresholds))]

        tmp_df = pd.Series(tmp, index=log.columns)

        # DataFrame.append was removed in pandas 2.0; concat a one-row frame instead
        log = pd.concat([log, tmp_df.to_frame().T], ignore_index=True)
        log.to_csv(log_path, index=False)

        # save best F1, edit and accuracy models; the validation entries of
        # `tmp` exist only when validation was run
        # tmp layout: [epoch, lr, train_loss, val_loss, acc, edit, f1s...]
        if CONFIG.param_search:
            if Best_F1 < tmp[6]:
                Best_F1 = tmp[6]
                Best_F1_All = tmp
                tmp_df.to_csv(os.path.join(CONFIG.result_path,
                                           'Best_F1_All.csv'),
                              index=False)
                torch.save(
                    model.state_dict(),
                    os.path.join(CONFIG.result_path, 'best_val_F1_model.prm'))
            if Best_Edit < tmp[5]:
                Best_Edit = tmp[5]
                Best_Edit_All = tmp
                tmp_df.to_csv(os.path.join(CONFIG.result_path,
                                           'Best_Edit_All.csv'),
                              index=False)
                torch.save(
                    model.state_dict(),
                    os.path.join(CONFIG.result_path,
                                 'best_val_Edit_model.prm'))
            if Best_Acc < tmp[4]:
                Best_Acc = tmp[4]
                Best_Acc_All = tmp
                tmp_df.to_csv(os.path.join(CONFIG.result_path,
                                           'Best_Acc_All.csv'),
                              index=False)
                torch.save(
                    model.state_dict(),
                    os.path.join(CONFIG.result_path, 'best_val_Acc_model.prm'))

        if CONFIG.param_search:
            # if you do validation to determine hyperparams
            print(
                'epoch: {}  lr: {:.4f}  train_time: {:.1f}s  val_time: {:.1f}s  train loss: {:.4f}  val loss: {:.4f}  val_acc: {:.4f}  val_edit: {:.4f} F1s: {}'
                .format(epoch, optimizer.param_groups[0]['lr'], train_time,
                        val_time, train_loss, val_loss, acc, edit_score, f1s))
        else:
            # train_time is measured in seconds, so label it as such
            print(
                'epoch: {}\tlr: {:.4f}\ttrain_time: {:.1f}s\ttrain loss: {:.4f}'
                .format(epoch, optimizer.param_groups[0]['lr'], train_time,
                        train_loss))
        All_Time = All_Time + train_time + val_time

    # save models
    torch.save(model.state_dict(),
               os.path.join(CONFIG.result_path, 'final_model.prm'))

    # summary of the best epochs; a record is None when validation never ran
    # or never produced a score above 0
    best_records = [
        ("**************************************************************  Best Acc ***************************************************************",
         Best_Acc_All),
        ("**************************************************************  Best Edit **************************************************************",
         Best_Edit_All),
        ("**************************************************************  Best F1 ***************************************************************",
         Best_F1_All),
    ]

    print("")
    print("")
    for banner, record in best_records:
        print(banner)
        print("")
        if record is None:
            print('no validation result was recorded.')
        else:
            print('epoch: {}\tlr: {:.4f}\tval_acc: {:.4f}\tval_edit: {:.4f}\tF1s: {}'.
                  format(record[0], record[1], record[4], record[5],
                         record[-3:]))
        print("")
    print(
        "**************************************************************   config  ****************************************************************"
    )
    print("")
    print("tmse_weight", CONFIG.tmse_weight, "  optimizer: ", CONFIG.optimizer,
          " scheduler: ", CONFIG.scheduler, "n_classes: ", n_classes)
    print("kernel_size", CONFIG.kernel_size, "  n_features: ",
          CONFIG.n_features, " in_channel: ", CONFIG.in_channel)
    print("Dataset: {}\tSplit: {}".format(CONFIG.dataset, CONFIG.split))
    print("Batch Size: {}\tNum in channels: {}\tNum Workers: {}".format(
        CONFIG.batch_size, CONFIG.in_channel, CONFIG.num_workers))
    print("Dataset: {}\tSplit: {}".format(CONFIG.dataset, CONFIG.split))
    print("train_data: ", len(train_data))
    print("")
    print(
        "***************************************************************************************************************************************"
    )
    print("")
    print("All_time: {:.4f}min".format(All_Time / 60))
    print(CONFIG.result_path)
コード例 #5
0
def main():
    """Train a Multi Stage TCN on an action segmentation dataset.

    Reads the YAML config named on the command line, builds the training
    data loader (plus a validation loader when ``CONFIG.param_search`` is
    set), the model, the optimizer, an optional LR scheduler and the loss,
    then runs the epoch loop.  After every epoch a checkpoint and
    ``log.csv`` are written under ``CONFIG.result_path``.  When validating,
    the weights with the lowest validation loss so far are stored as
    ``best_loss_model.prm``; the weights after the last epoch are stored as
    ``final_model.prm``.

    Exits with status 1 when CUDA is unavailable or when the configured
    optimizer name is not one of ``Adam``, ``SGD`` or ``AdaBound``.
    """
    # argparser
    args = get_arguments()

    # configuration; the context manager closes the YAML file once parsed
    with open(args.config) as f:
        CONFIG = Dict(yaml.safe_load(f))

    # seed the RNGs used by this script
    seed = args.seed
    random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.deterministic = True

    # cpu or cuda
    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    if device == 'cuda':
        # NOTE(review): benchmark mode lets cuDNN pick kernels at run time,
        # which may undercut the `deterministic = True` set above -- confirm
        # which of the two is actually wanted.
        torch.backends.cudnn.benchmark = True
    else:
        print(
            'You have to use GPUs because training CNN is computationally expensive.'
        )
        sys.exit(1)

    # Dataloader
    # Temporal downsampling is applied to only videos in 50Salads
    print("Dataset: {}\tSplit: {}".format(CONFIG.dataset, CONFIG.split))
    print("Batch Size: {}\tNum in channels: {}\tNum Workers: {}".format(
        CONFIG.batch_size, CONFIG.in_channel, CONFIG.num_workers))

    downsamp_rate = 2 if CONFIG.dataset == '50salads' else 1

    train_data = ActionSegmentationDataset(
        CONFIG.dataset,
        transform=Compose([
            ToTensor(),
            TempDownSamp(downsamp_rate),
        ]),
        # without a hyperparameter search the validation part is trained on too
        mode='trainval' if not CONFIG.param_search else 'training',
        split=CONFIG.split,
        dataset_dir=CONFIG.dataset_dir,
        csv_dir=CONFIG.csv_dir)

    train_loader = DataLoader(
        train_data,
        batch_size=CONFIG.batch_size,
        shuffle=True,
        num_workers=CONFIG.num_workers,
        # only drop the incomplete last batch when batching is actually used
        drop_last=CONFIG.batch_size > 1,
        collate_fn=collate_fn)

    # if you do validation to determine hyperparams
    if CONFIG.param_search:
        val_data = ActionSegmentationDataset(CONFIG.dataset,
                                             transform=Compose([
                                                 ToTensor(),
                                                 TempDownSamp(downsamp_rate),
                                             ]),
                                             mode='validation',
                                             split=CONFIG.split,
                                             dataset_dir=CONFIG.dataset_dir,
                                             csv_dir=CONFIG.csv_dir)

        val_loader = DataLoader(val_data,
                                batch_size=1,
                                shuffle=False,
                                num_workers=CONFIG.num_workers)

    # load model
    print('\n------------------------Loading Model------------------------\n')

    n_classes = get_n_classes(CONFIG.dataset, dataset_dir=CONFIG.dataset_dir)

    print('Multi Stage TCN will be used as a model.')
    print(
        'stages: {}\tn_features: {}\tn_layers of dilated TCN: {}\tkernel_size of ED-TCN: {}'
        .format(CONFIG.stages, CONFIG.n_features, CONFIG.dilated_n_layers,
                CONFIG.kernel_size))
    model = models.MultiStageTCN(in_channel=CONFIG.in_channel,
                                 n_classes=n_classes,
                                 stages=CONFIG.stages,
                                 n_features=CONFIG.n_features,
                                 dilated_n_layers=CONFIG.dilated_n_layers,
                                 kernel_size=CONFIG.kernel_size)

    # send the model to cuda/cpu
    model.to(device)

    if CONFIG.optimizer == 'Adam':
        print(CONFIG.optimizer + ' will be used as an optimizer.')
        optimizer = optim.Adam(model.parameters(), lr=CONFIG.learning_rate)
    elif CONFIG.optimizer == 'SGD':
        print(CONFIG.optimizer + ' will be used as an optimizer.')
        optimizer = optim.SGD(model.parameters(),
                              lr=CONFIG.learning_rate,
                              momentum=CONFIG.momentum,
                              dampening=CONFIG.dampening,
                              weight_decay=CONFIG.weight_decay,
                              nesterov=CONFIG.nesterov)
    elif CONFIG.optimizer == 'AdaBound':
        print(CONFIG.optimizer + ' will be used as an optimizer.')
        optimizer = adabound.AdaBound(model.parameters(),
                                      lr=CONFIG.learning_rate,
                                      final_lr=CONFIG.final_lr,
                                      weight_decay=CONFIG.weight_decay)
    else:
        print('There is no optimizer which suits to your option.')
        sys.exit(1)

    # learning rate scheduler
    # NOTE(review): the scheduler is created and handed to the checkpoint
    # helpers, but `scheduler.step(...)` is not called anywhere in this
    # function -- confirm whether it is meant to be stepped on the
    # validation loss after each epoch.
    if CONFIG.scheduler == 'onplateau':
        scheduler = optim.lr_scheduler.ReduceLROnPlateau(
            optimizer, 'min', patience=CONFIG.lr_patience)
    else:
        scheduler = None

    # columns of the training log
    columns = ['epoch', 'lr', 'train_loss']

    # if you do validation to determine hyperparams
    # (must test the same flag as the row-building code in the epoch loop,
    # otherwise the row length and the column count disagree)
    if CONFIG.param_search:
        columns += ['val_loss', 'acc', 'edit']
        columns += [
            "f1s@{}".format(threshold) for threshold in CONFIG.thresholds
        ]

    begin_epoch = 0
    # infinity so that the first validated epoch always becomes the best one
    best_loss = float('inf')
    log = pd.DataFrame(columns=columns)

    # resume if you want
    if args.resume:
        if os.path.exists(os.path.join(CONFIG.result_path, 'checkpoint.pth')):
            print('loading the checkpoint...')
            checkpoint = resume(CONFIG.result_path, model, optimizer,
                                scheduler)
            begin_epoch, model, optimizer, best_loss, scheduler = checkpoint
            print('training will start from {} epoch'.format(begin_epoch))
        else:
            print("there is no checkpoint at the result folder")
        if os.path.exists(os.path.join(CONFIG.result_path, 'log.csv')):
            print('loading the log file...')
            log = pd.read_csv(os.path.join(CONFIG.result_path, 'log.csv'))
        else:
            print("there is no log file at the result folder.")
            print('Making a log file...')

    # criterion for loss
    if CONFIG.class_weight:
        class_weight = get_class_weight(CONFIG.dataset,
                                        split=CONFIG.split,
                                        csv_dir=CONFIG.csv_dir)
        class_weight = class_weight.to(device)
    else:
        class_weight = None

    criterion = ActionSegmentationLoss(ce=CONFIG.ce,
                                       tmse=CONFIG.tmse,
                                       weight=class_weight,
                                       ignore_index=255,
                                       tmse_weight=CONFIG.tmse_weight)

    # train and validate model
    print(
        '\n---------------------------Start training---------------------------\n'
    )
    for epoch in range(begin_epoch, CONFIG.max_epoch):
        # training
        start = time.time()
        train_loss = train(train_loader, model, criterion, optimizer, epoch,
                           CONFIG, device)
        train_time = (time.time() - start) / 60

        # if you do validation to determine hyperparams
        if CONFIG.param_search:
            start = time.time()
            val_loss, acc, edit_score, f1s = validate(val_loader, model,
                                                      criterion, CONFIG,
                                                      device)
            val_time = (time.time() - start) / 60

            # save the model whenever the validation loss is lower than ever
            if best_loss > val_loss:
                best_loss = val_loss
                torch.save(
                    model.state_dict(),
                    os.path.join(CONFIG.result_path, 'best_loss_model.prm'))

        # save checkpoint every epoch
        save_checkpoint(CONFIG.result_path, epoch, model, optimizer, best_loss,
                        scheduler)

        # write logs to dataframe and csv file
        tmp = [epoch, optimizer.param_groups[0]['lr'], train_loss]

        # if you do validation to determine hyperparams
        if CONFIG.param_search:
            tmp += [val_loss, acc, edit_score]
            tmp += [f1s[-1][i] for i in range(len(CONFIG.thresholds))]

        # DataFrame.append was removed in pandas 2.0: build a one-row frame
        # (this raises if the value count does not match the log columns)
        # and concatenate it onto the existing log instead.
        row = pd.DataFrame([tmp], columns=log.columns)
        log = row if log.empty else pd.concat([log, row], ignore_index=True)
        log.to_csv(os.path.join(CONFIG.result_path, 'log.csv'), index=False)

        if CONFIG.param_search:
            # if you do validation to determine hyperparams
            print(
                'epoch: {}\tlr: {:.4f}\ttrain_time: {:.1f}min\tval_time: {:.1f}min\ttrain loss: {:.4f}\tval loss: {:.4f}\tval_acc: {:.4f}\tval_edit: {:.4f}'
                .format(epoch, optimizer.param_groups[0]['lr'], train_time,
                        val_time, train_loss, val_loss, acc, edit_score))
        else:
            print(
                'epoch: {}\tlr: {:.4f}\ttrain_time: {:.1f}min\ttrain loss: {:.4f}'
                .format(epoch, optimizer.param_groups[0]['lr'], train_time,
                        train_loss))

    # save models
    torch.save(model.state_dict(),
               os.path.join(CONFIG.result_path, 'final_model.prm'))

    print("Done!")
    print("")
コード例 #6
0
ファイル: train.py プロジェクト: jyp0802/asrf
def main() -> None:
    """Train the Action Segment Refinement Framework and log every epoch.

    The directory that contains the config file doubles as the result
    directory: the checkpoint, ``log.csv``, ``best_loss_model.prm`` (lowest
    validation loss, only written when ``config.param_search`` is set) and
    ``final_model.prm`` are all stored there.  The per-epoch checkpoint is
    deleted once training has finished.
    """
    # argparser
    args = get_arguments()

    # configuration
    config = get_config(args.config)

    result_path = os.path.dirname(args.config)

    # seed the RNGs used by this script
    seed = args.seed
    random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.deterministic = True

    # cpu or cuda
    device = "cuda" if torch.cuda.is_available() else "cpu"
    if device == "cuda":
        # NOTE(review): benchmark mode lets cuDNN pick kernels at run time,
        # which may undercut the `deterministic = True` set above -- confirm
        # which of the two is actually wanted.
        torch.backends.cudnn.benchmark = True

    # Dataloader
    # Temporal downsampling is applied to only videos in 50Salads
    downsamp_rate = 2 if config.dataset == "50salads" else 1

    train_data = ActionSegmentationDataset(
        config.dataset,
        transform=Compose([ToTensor(), TempDownSamp(downsamp_rate)]),
        # without a hyperparameter search the validation part is trained on too
        mode="trainval" if not config.param_search else "training",
        split=config.split,
        dataset_dir=config.dataset_dir,
        csv_dir=config.csv_dir,
    )

    train_loader = DataLoader(
        train_data,
        batch_size=config.batch_size,
        shuffle=True,
        num_workers=config.num_workers,
        # only drop the incomplete last batch when batching is actually used
        drop_last=config.batch_size > 1,
        collate_fn=collate_fn,
    )

    # if you do validation to determine hyperparams
    if config.param_search:
        val_data = ActionSegmentationDataset(
            config.dataset,
            transform=Compose([ToTensor(),
                               TempDownSamp(downsamp_rate)]),
            mode="validation",
            split=config.split,
            dataset_dir=config.dataset_dir,
            csv_dir=config.csv_dir,
        )

        val_loader = DataLoader(
            val_data,
            batch_size=1,
            shuffle=False,
            num_workers=config.num_workers,
            collate_fn=collate_fn,
        )

    # load model
    print("---------- Loading Model ----------")

    n_classes = get_n_classes(config.dataset, dataset_dir=config.dataset_dir)

    model = models.ActionSegmentRefinementFramework(
        in_channel=config.in_channel,
        n_features=config.n_features,
        n_classes=n_classes,
        n_stages=config.n_stages,
        n_layers=config.n_layers,
        n_stages_asb=config.n_stages_asb,
        n_stages_brb=config.n_stages_brb,
    )

    # send the model to cuda/cpu
    model.to(device)

    optimizer = get_optimizer(
        config.optimizer,
        model,
        config.learning_rate,
        momentum=config.momentum,
        dampening=config.dampening,
        weight_decay=config.weight_decay,
        nesterov=config.nesterov,
    )

    # columns of the training log
    columns = ["epoch", "lr", "train_loss"]

    # if you do validation to determine hyperparams
    if config.param_search:
        columns += ["val_loss", "cls_acc", "edit"]
        columns += [
            "segment f1s@{}".format(threshold)
            for threshold in config.iou_thresholds
        ]
        columns += ["bound_acc", "precision", "recall", "bound_f1s"]

    begin_epoch = 0
    best_loss = float("inf")
    log = pd.DataFrame(columns=columns)

    # resume if you want
    if args.resume:
        if os.path.exists(os.path.join(result_path, "checkpoint.pth")):
            checkpoint = resume(result_path, model, optimizer)
            begin_epoch, model, optimizer, best_loss = checkpoint
            log = pd.read_csv(os.path.join(result_path, "log.csv"))
            print("training will start from {} epoch".format(begin_epoch))
        else:
            print("there is no checkpoint at the result folder")

    # criterion for loss
    if config.class_weight:
        class_weight = get_class_weight(
            config.dataset,
            split=config.split,
            dataset_dir=config.dataset_dir,
            csv_dir=config.csv_dir,
            mode="training" if config.param_search else "trainval",
        )
        class_weight = class_weight.to(device)
    else:
        class_weight = None

    criterion_cls = ActionSegmentationLoss(
        ce=config.ce,
        focal=config.focal,
        tmse=config.tmse,
        gstmse=config.gstmse,
        weight=class_weight,
        ignore_index=255,
        ce_weight=config.ce_weight,
        focal_weight=config.focal_weight,
        tmse_weight=config.tmse_weight,
        # Pass the configured weight, not the on/off flag `config.gstmse`.
        # Assumes the config exposes `gstmse_weight` like the three sibling
        # `*_weight` settings above -- TODO confirm against the config class.
        gstmse_weight=config.gstmse_weight,
    )

    pos_weight = get_pos_weight(
        dataset=config.dataset,
        split=config.split,
        csv_dir=config.csv_dir,
        mode="training" if config.param_search else "trainval",
    ).to(device)

    criterion_bound = BoundaryRegressionLoss(pos_weight=pos_weight)

    # train and validate model
    print("---------- Start training ----------")

    for epoch in range(begin_epoch, config.max_epoch):
        # training
        train_loss = train(
            train_loader,
            model,
            criterion_cls,
            criterion_bound,
            config.lambda_b,
            optimizer,
            epoch,
            device,
        )

        # if you do validation to determine hyperparams
        if config.param_search:
            (
                val_loss,
                cls_acc,
                edit_score,
                segment_f1s,
                bound_acc,
                precision,
                recall,
                bound_f1s,
            ) = validate(
                val_loader,
                model,
                criterion_cls,
                criterion_bound,
                config.lambda_b,
                device,
                config.dataset,
                config.dataset_dir,
                config.iou_thresholds,
                config.boundary_th,
                config.tolerance,
            )

            # save the model whenever the validation loss is lower than ever
            if best_loss > val_loss:
                best_loss = val_loss
                torch.save(
                    model.state_dict(),
                    os.path.join(result_path, "best_loss_model.prm"),
                )

        # save checkpoint every epoch
        save_checkpoint(result_path, epoch, model, optimizer, best_loss)

        # write logs to dataframe and csv file
        tmp = [epoch, optimizer.param_groups[0]["lr"], train_loss]

        # if you do validation to determine hyperparams
        if config.param_search:
            tmp += [
                val_loss,
                cls_acc,
                edit_score,
            ]
            tmp += segment_f1s
            tmp += [
                bound_acc,
                precision,
                recall,
                bound_f1s,
            ]

        # DataFrame.append was removed in pandas 2.0: build a one-row frame
        # (this raises if the value count does not match the log columns)
        # and concatenate it onto the existing log instead.
        row = pd.DataFrame([tmp], columns=log.columns)
        log = row if log.empty else pd.concat([log, row], ignore_index=True)
        log.to_csv(os.path.join(result_path, "log.csv"), index=False)

        if config.param_search:
            # if you do validation to determine hyperparams
            print(
                "epoch: {}\tlr: {:.4f}\ttrain loss: {:.4f}\tval loss: {:.4f}\tval_acc: {:.4f}\tedit: {:.4f}"
                .format(
                    epoch,
                    optimizer.param_groups[0]["lr"],
                    train_loss,
                    val_loss,
                    cls_acc,
                    edit_score,
                ))
        else:
            print("epoch: {}\tlr: {:.4f}\ttrain loss: {:.4f}".format(
                epoch, optimizer.param_groups[0]["lr"], train_loss))

    # delete checkpoint; it may be absent when the loop above ran zero epochs
    checkpoint_path = os.path.join(result_path, "checkpoint.pth")
    if os.path.exists(checkpoint_path):
        os.remove(checkpoint_path)

    # save models
    torch.save(model.state_dict(), os.path.join(result_path,
                                                "final_model.prm"))

    print("Done!")
コード例 #7
0
def main():
    """Train a generator/discriminator pair and log the losses per epoch.

    The directory that contains the config file is used as the result
    directory.  Training is skipped when ``final_model_G.prm`` already
    exists there.  After every epoch a checkpoint and ``log.csv`` are
    written and the graphs are regenerated from the log; the weights with
    the lowest summed D+G training loss are stored as ``best_model_<key>.prm``.
    At the end the final weights are stored as ``final_model_<key>.prm`` and
    the per-model checkpoints are removed.

    Raises:
        FileNotFoundError: when resuming and ``log.csv`` is missing from the
            result directory.
    """
    args = get_arguments()
    config = get_config(args.config)

    result_path = os.path.dirname(args.config)
    experiment_name = os.path.basename(result_path)
    log_path = os.path.join(result_path, "log.csv")

    if os.path.exists(os.path.join(result_path, "final_model_G.prm")):
        print("Already done.")
        return

    device = get_device(allow_only_gpu=True)

    train_loader = get_dataloader(
        csv_file=config.train_csv,
        batch_size=config.batch_size,
        shuffle=True,
        num_workers=config.num_workers,
        pin_memory=True,
        drop_last=True,
        transform=ImageTransform(mean=get_mean(), std=get_std()),
    )

    # `model` is a mapping; the code below reads its "G" and "D" entries
    model = get_model(config.model, z_dim=config.z_dim, image_size=config.size)
    for v in model.values():
        v.to(device)

    g_optimizer = torch.optim.Adam(
        model["G"].parameters(),
        config.g_lr,
        [config.beta1, config.beta2],
    )
    d_optimizer = torch.optim.Adam(
        model["D"].parameters(),
        config.d_lr,
        [config.beta1, config.beta2],
    )
    optimizer = {
        "G": g_optimizer,
        "D": d_optimizer,
    }

    begin_epoch = 0
    best_loss = float("inf")
    # TODO: consider which evaluation metrics to use
    log = pd.DataFrame(
        columns=[
            "epoch",
            "d_lr",
            "g_lr",
            "train_time[sec]",
            "train_loss",
            "train_d_loss",
            "train_g_loss",
        ]
    )

    if args.resume:
        # the "%s" placeholder is passed through unformatted; presumably
        # `resume` fills it in per model key -- verify against the helper
        resume_path = os.path.join(result_path, "checkpoint_%s.pth")
        begin_epoch, model, optimizer, best_loss = resume(resume_path, model, optimizer)

        # An explicit raise instead of `assert`: asserts are stripped under
        # `python -O`, and the old message wrongly blamed the checkpoint.
        if not os.path.exists(log_path):
            raise FileNotFoundError(
                "there is no log file at the result folder: %s" % log_path
            )
        log = pd.read_csv(log_path)

    criterion = nn.BCEWithLogitsLoss(reduction="mean")

    print("---------- Start training ----------")

    for epoch in range(begin_epoch, config.max_epoch):
        start = time.time()
        train_d_loss, train_g_loss = train(
            train_loader,
            model,
            config.model,
            criterion,
            optimizer,
            epoch,
            config.z_dim,
            device,
            interval_of_progress=1,
        )
        train_time = int(time.time() - start)
        train_loss = train_d_loss + train_g_loss

        # keep the weights with the lowest summed training loss so far
        if best_loss > train_loss:
            best_loss = train_loss
            for k in model.keys():
                torch.save(
                    model[k].state_dict(),
                    os.path.join(result_path, "best_model_%s.prm" % k),
                )

        save_checkpoint(result_path, epoch, model, optimizer, best_loss)

        # DataFrame.append was removed in pandas 2.0: build a one-row frame
        # (this raises if the value count does not match the log columns)
        # and concatenate it onto the existing log instead.
        row = pd.DataFrame(
            [
                [
                    epoch,
                    optimizer["D"].param_groups[0]["lr"],
                    optimizer["G"].param_groups[0]["lr"],
                    train_time,
                    train_loss,
                    train_d_loss,
                    train_g_loss,
                ]
            ],
            columns=log.columns,
        )
        log = row if log.empty else pd.concat([log, row], ignore_index=True)
        log.to_csv(log_path, index=False)
        make_graphs(log_path)

        print(
            "epoch: {}\tepoch time[sec]: {}\tD_lr: {}\tG_lr: {}\ttrain loss: {:.4f}\ttrain d_loss: {:.4f}\ttrain g_loss: {:.4f}".format(
                epoch,
                train_time,
                optimizer["D"].param_groups[0]["lr"],
                optimizer["G"].param_groups[0]["lr"],
                train_loss,
                train_d_loss,
                train_g_loss,
            )
        )

    for k in model.keys():
        torch.save(
            model[k].state_dict(),
            os.path.join(result_path, "final_model_%s.prm" % k),
        )

    # remove the per-model checkpoints; they may be absent when the loop
    # above ran zero epochs
    for k in model.keys():
        checkpoint_path = os.path.join(result_path, "checkpoint_%s.pth" % k)
        if os.path.exists(checkpoint_path):
            os.remove(checkpoint_path)

    print("Done")
コード例 #8
0
def main():
    """Train an encoder/decoder captioning model on MSR-VTT features.

    Reads the YAML config named on the command line, builds the train and
    validation loaders, the encoder/decoder pair, the optimizer and an
    optional LR scheduler, then trains and validates every epoch.  Under
    ``CONFIG.result_path`` it writes a checkpoint every epoch (plus an
    epoch-named one every 10 epochs), ``log.csv``, the encoder/decoder
    weights with the lowest validation loss, and the final weights.  The
    losses are also sent to a summary writer when ``CONFIG.writer_flag`` is
    set.

    Exits with status 1 when CUDA is unavailable or when the configured
    optimizer name is not one of ``Adam``, ``SGD`` or ``AdaBound``.
    """
    args = get_arguments()

    # configuration; the context manager closes the YAML file once parsed
    with open(args.config) as f:
        CONFIG = Dict(yaml.safe_load(f))

    # cpu or cuda
    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    if device == 'cuda':
        torch.backends.cudnn.benchmark = True
    else:
        print(
            'You have to use GPUs because training CNN is computationally expensive.'
        )
        sys.exit(1)

    # writer
    if CONFIG.writer_flag:
        writer = SummaryWriter(CONFIG.result_path)
    else:
        writer = None

    # Dataloader
    print("Dataset: {}".format(CONFIG.dataset))
    print(
        "Batch Size: {}\tNum in channels: {}\tAlignment Size: {}\tNum Workers: {}"
        .format(CONFIG.batch_size, CONFIG.in_channels, CONFIG.align_size,
                CONFIG.num_workers))

    # load vocabulary
    # NOTE: unpickling can execute arbitrary code; only point `vocab_path`
    # at a file from a trusted source.
    with open(CONFIG.vocab_path, 'rb') as f:
        vocab = pickle.load(f)

    train_data = MSR_VTT_Features(dataset_dir=CONFIG.dataset_dir,
                                  feature_dir=CONFIG.feature_dir,
                                  vocab=vocab,
                                  ann_file=CONFIG.ann_file,
                                  mode="train",
                                  align_size=CONFIG.align_size)

    val_data = MSR_VTT_Features(dataset_dir=CONFIG.dataset_dir,
                                feature_dir=CONFIG.feature_dir,
                                vocab=vocab,
                                ann_file=CONFIG.ann_file,
                                mode="val",
                                align_size=CONFIG.align_size)

    train_loader = DataLoader(
        train_data,
        batch_size=CONFIG.batch_size,
        shuffle=True,
        num_workers=CONFIG.num_workers,
        collate_fn=collate_fn,
        # only drop the incomplete last batch when batching is actually used
        drop_last=CONFIG.batch_size > 1)

    val_loader = DataLoader(val_data,
                            batch_size=CONFIG.batch_size,
                            shuffle=False,
                            num_workers=CONFIG.num_workers,
                            collate_fn=collate_fn)

    # load encoder, decoder
    print(
        '\n------------------------Loading encoder, decoder------------------------\n'
    )
    encoder = EncoderCNN(CONFIG.in_channels, CONFIG.embed_size)
    decoder = DecoderRNN(CONFIG.embed_size, CONFIG.hidden_size, len(vocab),
                         CONFIG.num_layers)

    # send the encoder, decoder to cuda/cpu
    encoder.to(device)
    decoder.to(device)

    # only the decoder and the encoder's `linear` sub-module are optimized
    params = list(decoder.parameters()) + list(encoder.linear.parameters())

    if CONFIG.optimizer == 'Adam':
        print(CONFIG.optimizer + ' will be used as an optimizer.')
        optimizer = optim.Adam(params, lr=CONFIG.learning_rate)
    elif CONFIG.optimizer == 'SGD':
        print(CONFIG.optimizer + ' will be used as an optimizer.')
        optimizer = optim.SGD(params,
                              lr=CONFIG.learning_rate,
                              momentum=CONFIG.momentum,
                              dampening=CONFIG.dampening,
                              weight_decay=CONFIG.weight_decay,
                              nesterov=CONFIG.nesterov)
    elif CONFIG.optimizer == 'AdaBound':
        print(CONFIG.optimizer + ' will be used as an optimizer.')
        optimizer = adabound.AdaBound(params,
                                      lr=CONFIG.learning_rate,
                                      final_lr=CONFIG.final_lr,
                                      weight_decay=CONFIG.weight_decay)
    else:
        print('There is no optimizer which suits to your option.')
        sys.exit(1)

    # learning rate scheduler (only built for the SGD optimizer)
    if CONFIG.scheduler == 'onplateau' and CONFIG.optimizer == "SGD":
        scheduler = optim.lr_scheduler.ReduceLROnPlateau(
            optimizer, 'min', patience=CONFIG.lr_patience)
    else:
        scheduler = None

    # columns of the training log
    columns = ['epoch', 'lr', 'train_loss', 'val_loss']
    log = pd.DataFrame(columns=columns)

    begin_epoch = 0
    # infinity so that the first epoch always becomes the best one
    best_loss = float('inf')

    # resume if you want
    if args.resume:
        if os.path.exists(os.path.join(CONFIG.result_path, 'checkpoint.pth')):
            print('loading the checkpoint...')
            checkpoint = resume(CONFIG.result_path, encoder, decoder,
                                optimizer, scheduler)
            begin_epoch, encoder, decoder, optimizer, best_loss, scheduler = checkpoint
            print('training will start from {} epoch'.format(begin_epoch))
        else:
            print("there is no checkpoint at the result folder")
        if os.path.exists(os.path.join(CONFIG.result_path, 'log.csv')):
            print('loading the log file...')
            log = pd.read_csv(os.path.join(CONFIG.result_path, 'log.csv'))
        else:
            print("there is no log file at the result folder.")
            print('Making a log file...')

    # criterion for loss
    criterion = nn.CrossEntropyLoss()

    # train and validate encoder, decoder
    print(
        '\n---------------------------Start training---------------------------\n'
    )

    for epoch in range(begin_epoch, CONFIG.max_epoch):
        # training
        train_loss = train(train_loader, encoder, decoder, criterion,
                           optimizer, epoch, CONFIG, device)

        # validation
        val_loss = validate(val_loader, encoder, decoder, criterion, CONFIG,
                            device)

        # scheduler
        # Test the object, not the config string: with 'onplateau' and a
        # non-SGD optimizer no scheduler was built above, and stepping
        # `None` would crash after the first epoch.
        if scheduler is not None:
            scheduler.step(val_loss)

        # save the encoder, decoder whenever the validation loss is lower
        # than ever
        if best_loss > val_loss:
            best_loss = val_loss
            torch.save(
                encoder.state_dict(),
                os.path.join(CONFIG.result_path, 'best_loss_encoder.prm'))

            torch.save(
                decoder.state_dict(),
                os.path.join(CONFIG.result_path, 'best_loss_decoder.prm'))

        # save checkpoint every epoch
        save_checkpoint(CONFIG.result_path, epoch, encoder, decoder, optimizer,
                        best_loss, scheduler)

        # save checkpoint every 10 epoch
        if epoch % 10 == 0 and epoch != 0:
            save_checkpoint(CONFIG.result_path,
                            epoch,
                            encoder,
                            decoder,
                            optimizer,
                            best_loss,
                            scheduler,
                            add_epoch2name=True)

        # tensorboardx
        if writer is not None:
            writer.add_scalars("loss", {
                'train': train_loss,
                'val': val_loss
            }, epoch)

        # write logs to dataframe and csv file
        tmp = [epoch, optimizer.param_groups[0]['lr'], train_loss, val_loss]

        # DataFrame.append was removed in pandas 2.0: build a one-row frame
        # (this raises if the value count does not match the log columns)
        # and concatenate it onto the existing log instead.
        row = pd.DataFrame([tmp], columns=log.columns)
        log = row if log.empty else pd.concat([log, row], ignore_index=True)
        log.to_csv(os.path.join(CONFIG.result_path, 'log.csv'), index=False)

        print('epoch: {}\tlr: {:.4f}\ttrain loss: {:.4f}\tval loss: {:.4f}'.
              format(epoch, optimizer.param_groups[0]['lr'], train_loss,
                     val_loss))

    # save encoder, decoders
    torch.save(encoder.state_dict(),
               os.path.join(CONFIG.result_path, 'final_encoder.prm'))
    torch.save(decoder.state_dict(),
               os.path.join(CONFIG.result_path, 'final_decoder.prm'))

    # flush and release the event file held by the summary writer
    if writer is not None:
        writer.close()

    print("Done!")
    print("")