Beispiel #1
0
def train(exp=None):
    """
    Train the encoder/decoder (ED) network and checkpoint it.

    The network is built from the module-level ``encoder_params`` /
    ``decoder_params``.  If ``<save_dir>/checkpoint.pth.tar`` exists, the
    model weights, optimizer state and epoch counter are restored from it
    and training resumes at the following epoch.  Epochs ``cur_epoch`` ..
    ``args.epochs`` (inclusive) are then run, each consisting of one pass
    over ``trainLoader`` and up to 3000 batches of ``validLoader``.

    After every epoch the mean validation loss drives the ReduceLROnPlateau
    scheduler and the ``EarlyStopping`` helper; every ``args.save_every``
    epochs (epoch 0 excluded) an additional checkpoint named after the epoch
    and validation loss is written to ``save_dir``.  When training ends, the
    per-epoch mean losses are written to ``avg_train_losses.txt`` and
    ``avg_valid_losses.txt`` in the current working directory.

    Args:
        exp: optional experiment tracker with a ``log_metric(name, value)``
            method (presumably a comet_ml ``Experiment`` -- TODO confirm);
            ``None`` disables metric logging.
    """
    encoder = Encoder(encoder_params[0], encoder_params[1]).cuda()
    decoder = Decoder(decoder_params[0], decoder_params[1]).cuda()
    net = ED(encoder, decoder)
    run_dir = "./runs/" + TIMESTAMP
    if not os.path.isdir(run_dir):
        os.makedirs(run_dir)
    # initialize the early_stopping object
    early_stopping = EarlyStopping(patience=20, verbose=True)
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

    if torch.cuda.device_count() > 1:
        net = nn.DataParallel(net)
    net.to(device)

    # Create the optimizer before a possible resume so that the restored
    # optimizer state is actually used instead of being replaced by a
    # fresh Adam instance afterwards.
    optimizer = optim.Adam(net.parameters(), lr=args.lr)
    checkpoint_path = os.path.join(save_dir, "checkpoint.pth.tar")
    if os.path.exists(checkpoint_path):
        # load existing model from the same file whose existence was checked
        print("==> loading existing model")
        model_info = torch.load(checkpoint_path)
        net.load_state_dict(model_info["state_dict"])
        optimizer.load_state_dict(model_info["optimizer"])
        cur_epoch = model_info["epoch"] + 1
    else:
        if not os.path.isdir(save_dir):
            os.makedirs(save_dir)
        cur_epoch = 0
    lossfunction = nn.MSELoss().cuda()
    pla_lr_scheduler = lr_scheduler.ReduceLROnPlateau(optimizer,
                                                      factor=0.5,
                                                      patience=4,
                                                      verbose=True)

    # per-batch losses of the current epoch (cleared after every epoch)
    train_losses = []
    valid_losses = []
    # per-epoch mean losses, written to disk when training ends
    avg_train_losses = []
    avg_valid_losses = []
    for epoch in range(cur_epoch, args.epochs + 1):
        if exp is not None:
            exp.log_metric("epoch", epoch)
        ###################
        # train the model #
        ###################
        t = tqdm(trainLoader, leave=False, total=len(trainLoader))
        for i, (idx, targetVar, inputVar, _, _) in enumerate(t):
            inputs = inputVar.to(device)  # B,S,C,H,W
            label = targetVar.to(device)  # B,S,C,H,W
            optimizer.zero_grad()
            net.train()
            pred = net(inputs)  # B,S,C,H,W
            loss = lossfunction(pred, label)
            loss_aver = loss.item() / args.batch_size
            train_losses.append(loss_aver)
            loss.backward()
            torch.nn.utils.clip_grad_value_(net.parameters(), clip_value=10.0)
            optimizer.step()
            t.set_postfix({
                "trainloss": "{:.6f}".format(loss_aver),
                "epoch": "{:02d}".format(epoch),
            })
        ######################
        # validate the model #
        ######################
        with torch.no_grad():
            net.eval()
            t = tqdm(validLoader, leave=False, total=len(validLoader))
            for i, (idx, targetVar, inputVar, _, _) in enumerate(t):
                # cap on the number of validation batches per epoch
                if i == 3000:
                    break
                inputs = inputVar.to(device)
                label = targetVar.to(device)
                pred = net(inputs)
                loss = lossfunction(pred, label)
                loss_aver = loss.item() / args.batch_size
                # record validation loss
                valid_losses.append(loss_aver)
                t.set_postfix({
                    "validloss": "{:.6f}".format(loss_aver),
                    "epoch": "{:02d}".format(epoch),
                })
        torch.cuda.empty_cache()
        # calculate average loss over an epoch
        train_loss = np.average(train_losses)
        valid_loss = np.average(valid_losses)
        avg_train_losses.append(train_loss)
        avg_valid_losses.append(valid_loss)

        if exp is not None:
            exp.log_metric("TrainLoss", train_loss)
            exp.log_metric("ValidLoss", valid_loss)
        # clear lists to track next epoch
        train_losses = []
        valid_losses = []
        pla_lr_scheduler.step(valid_loss)  # lr_scheduler
        model_dict = {
            "epoch": epoch,
            "state_dict": net.state_dict(),
            "optimizer": optimizer.state_dict(),
        }
        if epoch % args.save_every == 0 and epoch != 0:
            torch.save(
                model_dict,
                save_dir + "/" + "checkpoint_{}_{:.6f}.pth.tar".format(
                    epoch, valid_loss.item()),
            )
        early_stopping(valid_loss.item(), model_dict, epoch, save_dir)
        if early_stopping.early_stop:
            print("Early stopping")
            break

    with open("avg_train_losses.txt", "wt") as f:
        for i in avg_train_losses:
            print(i, file=f)

    with open("avg_valid_losses.txt", "wt") as f:
        for i in avg_valid_losses:
            print(i, file=f)
Beispiel #2
0
def test():
    '''
    Run the trained ED model on the MovingMNIST test split and pickle the
    predictions.

    Builds the test loader from ``../data/npy-064/`` (mode ``'test'``, 3
    objects per sequence) and chooses the ConvLSTM or ConvGRU encoder/decoder
    parameters from ``args.convlstm`` / ``args.convgru``; ConvGRU is used
    when neither flag is set.  It then restores the training-time arguments
    from ``./save_model/<args.timestamp>/cmd_args.txt`` if that file is
    present, loads the weights from ``args.checkpoint`` and predicts up to
    3000 test batches.  The predictions, moved to CPU, are pickled to
    ``preds.pkl`` in the current working directory.

    Raises:
        SystemExit: with status 1 if ``args.checkpoint`` does not exist.
    '''
    import pickle
    import sys

    testFolder = MovingMNIST(is_train=False,
                             root='../data/npy-064/',
                             mode='test',
                             n_frames_input=args.frames_input,
                             n_frames_output=args.frames_output,
                             num_objects=[3])
    testLoader = torch.utils.data.DataLoader(testFolder,
                                             batch_size=args.batch_size,
                                             shuffle=False)

    # Choose the recurrent cell type.  `elif` prevents a ConvLSTM selection
    # from being overwritten by the ConvGRU default in the final `else`.
    if args.convlstm:
        encoder_params = convlstm_encoder_params
        decoder_params = convlstm_decoder_params
    elif args.convgru:
        encoder_params = convgru_encoder_params
        decoder_params = convgru_decoder_params
    else:
        encoder_params = convgru_encoder_params
        decoder_params = convgru_decoder_params

    # Read these before restoring args: the restore below replaces
    # args.__dict__, so the checkpoint/timestamp given for this run are used
    # rather than the restored ones.
    CHECKPOINT = args.checkpoint
    TIMESTAMP = args.timestamp
    save_dir = './save_model/' + TIMESTAMP

    args_path = os.path.join(save_dir, 'cmd_args.txt')
    if os.path.exists(args_path):
        with open(args_path, 'r') as f:
            args.__dict__ = json.load(f)
        args.is_train = False
    encoder = Encoder(encoder_params[0], encoder_params[1]).cuda()
    decoder = Decoder(decoder_params[0], decoder_params[1],
                      args.frames_output).cuda()
    net = ED(encoder, decoder)
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

    if torch.cuda.device_count() > 1:
        net = nn.DataParallel(net)
    net.to(device)

    # Check the checkpoint file that is loaded, not only its run directory.
    if not os.path.exists(CHECKPOINT):
        print('there is no such checkpoint:', CHECKPOINT)
        sys.exit(1)
    print('==> loading existing model')
    model_info = torch.load(CHECKPOINT)
    net.load_state_dict(model_info['state_dict'])

    preds = []
    ######################
    # test the model #
    ######################
    with torch.no_grad():
        net.eval()
        t = tqdm(testLoader, leave=False, total=len(testLoader))
        for i, (idx, targetVar, inputVar, _, _) in enumerate(t):
            # cap on the number of predicted batches
            if i == 3000:
                break
            inputs = inputVar.to(device)
            pred = net(inputs)
            # Keep predictions on the CPU so they do not accumulate in GPU
            # memory and preds.pkl can be unpickled without a GPU.
            preds.append(pred.cpu())

    torch.cuda.empty_cache()

    with open("preds.pkl", "wb") as fp:
        pickle.dump(preds, fp)
def train():
    '''
    Train the encoder/decoder (ED) network with TensorBoard logging.

    Builds the network from the module-level ``encoder_params`` /
    ``decoder_params``.  If ``<save_dir>/checkpoint.pth.tar`` exists, model
    weights, optimizer state and epoch counter are restored from it.  Epochs
    ``cur_epoch`` .. ``args.epochs`` (inclusive) are run.  Each epoch:
    - trains on ``trainLoader``;
    - evaluates up to 3000 batches of ``validLoader``;
    - logs the last batch loss of each phase to TensorBoard under
      ``./runs/<TIMESTAMP>``;
    - prints the epoch's mean losses;
    - steps the ReduceLROnPlateau scheduler and the ``EarlyStopping`` helper.
    The per-epoch mean losses are written to ``avg_train_losses.txt`` /
    ``avg_valid_losses.txt`` in the current working directory at the end.
    '''
    # Instantiate the encoder and decoder.
    encoder = Encoder(encoder_params[0], encoder_params[1]).cuda()
    decoder = Decoder(decoder_params[0], decoder_params[1]).cuda()

    # Wrap them in the ED model.
    net = ED(encoder, decoder)
    # TensorBoard run directory; create it if it does not exist.
    run_dir = './runs/' + TIMESTAMP
    if not os.path.isdir(run_dir):
        os.makedirs(run_dir)
    tb = SummaryWriter(run_dir)
    # initialize the early_stopping object
    early_stopping = EarlyStopping(patience=20, verbose=True)

    # Use the first GPU when CUDA is available.
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

    # Use data parallelism when more than one GPU is present.
    if torch.cuda.device_count() > 1:
        net = nn.DataParallel(net)
    net.to(device)

    # Optimizer: Adam.  Created before a possible resume so that the restored
    # optimizer state is kept rather than replaced by a fresh instance.
    optimizer = optim.Adam(net.parameters(), lr=args.lr)
    checkpoint_path = os.path.join(save_dir, 'checkpoint.pth.tar')
    if os.path.exists(checkpoint_path):
        # Load the saved model from the file that was just checked.
        print('==> loading existing model')
        model_info = torch.load(checkpoint_path)
        net.load_state_dict(model_info['state_dict'])
        optimizer.load_state_dict(model_info['optimizer'])
        cur_epoch = model_info['epoch'] + 1
    else:
        # No checkpoint: make sure save_dir exists and start at epoch 0.
        if not os.path.isdir(save_dir):
            os.makedirs(save_dir)
        cur_epoch = 0
    # Loss function: mean squared error.
    lossfunction = nn.MSELoss().cuda()

    pla_lr_scheduler = lr_scheduler.ReduceLROnPlateau(optimizer,
                                                      factor=0.5,
                                                      patience=4,
                                                      verbose=True)

    # per-batch losses of the current epoch (cleared after every epoch)
    train_losses = []
    valid_losses = []
    # per-epoch mean losses, written to disk when training ends
    avg_train_losses = []
    avg_valid_losses = []
    for epoch in range(cur_epoch, args.epochs + 1):
        ###################
        # train the model #
        ###################
        t = tqdm(trainLoader, leave=False, total=len(trainLoader))
        for i, (idx, targetVar, inputVar, _, _) in enumerate(t):
            inputs = inputVar.to(device)  # B,S,C,H,W
            label = targetVar.to(device)  # B,S,C,H,W
            optimizer.zero_grad()
            net.train()
            pred = net(inputs)  # B,S,C,H,W
            loss = lossfunction(pred, label)
            loss_aver = loss.item() / args.batch_size
            train_losses.append(loss_aver)
            loss.backward()
            torch.nn.utils.clip_grad_value_(net.parameters(), clip_value=10.0)
            optimizer.step()
            t.set_postfix({
                'trainloss': '{:.6f}'.format(loss_aver),
                'epoch': '{:02d}'.format(epoch)
            })
        tb.add_scalar('TrainLoss', loss_aver, epoch)
        ######################
        # validate the model #
        ######################
        with torch.no_grad():
            net.eval()
            t = tqdm(validLoader, leave=False, total=len(validLoader))
            for i, (idx, targetVar, inputVar, _, _) in enumerate(t):
                # cap on the number of validation batches per epoch
                if i == 3000:
                    break
                inputs = inputVar.to(device)
                label = targetVar.to(device)
                pred = net(inputs)
                loss = lossfunction(pred, label)
                loss_aver = loss.item() / args.batch_size
                # record validation loss
                valid_losses.append(loss_aver)
                t.set_postfix({
                    'validloss': '{:.6f}'.format(loss_aver),
                    'epoch': '{:02d}'.format(epoch)
                })

        tb.add_scalar('ValidLoss', loss_aver, epoch)
        torch.cuda.empty_cache()
        # calculate average loss over an epoch
        train_loss = np.average(train_losses)
        valid_loss = np.average(valid_losses)
        avg_train_losses.append(train_loss)
        avg_valid_losses.append(valid_loss)

        epoch_len = len(str(args.epochs))

        print_msg = (f'[{epoch:>{epoch_len}}/{args.epochs:>{epoch_len}}] ' +
                     f'train_loss: {train_loss:.6f} ' +
                     f'valid_loss: {valid_loss:.6f}')

        print(print_msg)
        # clear lists to track next epoch
        train_losses = []
        valid_losses = []
        pla_lr_scheduler.step(valid_loss)  # lr_scheduler
        model_dict = {
            'epoch': epoch,
            'state_dict': net.state_dict(),
            'optimizer': optimizer.state_dict()
        }
        early_stopping(valid_loss.item(), model_dict, epoch, save_dir)
        if early_stopping.early_stop:
            print("Early stopping")
            break

    # Flush pending events and release the TensorBoard event file.
    tb.close()

    with open("avg_train_losses.txt", 'wt') as f:
        for i in avg_train_losses:
            print(i, file=f)

    with open("avg_valid_losses.txt", 'wt') as f:
        for i in avg_valid_losses:
            print(i, file=f)
Beispiel #4
0
def test():
    '''
    Evaluate a saved model on ``testLoader`` and save visualisations.

    Loads the weights stored at ``args.model_path`` and runs the network on
    every test batch, or only on the first ``args.sample`` batches when
    ``args.sample > 0``.  For each batch it passes the inputs, masks and
    predictions to ``SaveVis`` and records a class-weighted cross-entropy
    loss.  The mean test loss is printed and returned.

    Returns:
        The mean of the per-batch test losses.

    Raises:
        FileNotFoundError: if ``args.model_path`` does not exist.
    '''
    encoder = Encoder(encoder_params[0], encoder_params[1]).cuda()
    decoder = Decoder(decoder_params[0], decoder_params[1]).cuda()
    net = ED(encoder, decoder)

    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

    if torch.cuda.device_count() > 1:
        net = nn.DataParallel(net)
    net.to(device)

    # Load the model under test.  FileNotFoundError subclasses Exception, so
    # callers that catch Exception keep working.
    if not os.path.exists(args.model_path):
        raise FileNotFoundError("Invalid model path!")
    print('==> loading existing model ' + args.model_path)
    model_info = torch.load(args.model_path)
    net.load_state_dict(model_info['state_dict'])
    # name of the directory holding the checkpoint; passed to SaveVis
    model_dir = args.model_path.split('/')[-2]

    # Create the directory for the visualisation images.
    if not os.path.isdir(args.vis_dir):
        os.makedirs(args.vis_dir)

    # Class 1 is weighted 15x relative to class 0 (presumably to counter a
    # sparse foreground mask -- TODO confirm).
    class_weights = torch.FloatTensor([1.0, 15.0]).cuda()
    lossfunction = nn.CrossEntropyLoss(weight=class_weights).cuda()
    # per-batch test losses
    test_losses = []

    ######################
    # test the model #
    ######################
    with torch.no_grad():
        # eval mode only changes layers such as dropout and batch norm
        net.eval()
        # tqdm progress bar
        t = tqdm(testLoader, total=len(testLoader))
        for i, (seq_len, scan_seq, label_seq, mask_seq,
                label_id) in enumerate(t):
            # Sequence length varies; at least the first 2 frames are used as
            # input and the last 3 frames are always predicted.
            inputs = torch.cat((scan_seq, mask_seq.float()),
                               dim=2).to(device)[:, :-3, ...]  # B,S,C,H,W
            label = mask_seq.to(device)[:, (seq_len - 3):, ...]  # B,S,C,H,W
            pred = net(inputs)
            SaveVis(model_dir, i, scan_seq.to(device), mask_seq.to(device),
                    pred)
            # Merge the two leading dimensions before computing the loss.
            _, _, input_channel, height, width = pred.size()
            pred = pred.reshape(-1, input_channel, height,
                                width)  # reshape to B*S,C,H,W
            _, _, _, height, width = label.size()
            label = label.reshape(-1, height, width)  # reshape to B*S,H,W
            label = label.to(device=device, dtype=torch.long)
            loss = lossfunction(pred, label)
            loss_aver = loss.item() / (label.shape[0])
            # record test loss
            test_losses.append(loss_aver)
            t.set_postfix({
                'test_loss': '{:.6f}'.format(loss_aver),
                'cnt': '{:02d}'.format(i)
            })
            # args.sample caps the number of evaluated batches (<= 0: all).
            # i is 0-based, so i + 1 batches have been processed here.
            if args.sample > 0 and i + 1 >= args.sample:
                break

    torch.cuda.empty_cache()
    # mean test loss over the evaluated batches
    test_loss = np.average(test_losses)
    print(f'test_loss: {test_loss:.6f}')
    return test_loss
Beispiel #5
0
def test():
    '''
    Evaluate a trained checkpoint on ``validLoader`` and pickle predictions.

    The training-time arguments are restored from
    ``./save_model/<args.timestamp>/cmd_args.txt`` when that file exists.
    The network is then rebuilt, loaded from ``args.checkpoint`` and run on
    up to 3000 validation batches.  The mean validation loss is printed, and
    the predictions (moved to CPU) are pickled to ``preds.pkl`` in the
    current working directory.

    Raises:
        SystemExit: with status 1 if ``args.checkpoint`` does not exist.
    '''
    import pickle
    import sys

    # Read these before restoring args: the restore below replaces
    # args.__dict__, so the checkpoint/timestamp given for this run are used
    # rather than the restored ones.
    CHECKPOINT = args.checkpoint
    TIMESTAMP = args.timestamp
    save_dir = './save_model/' + TIMESTAMP

    args_path = os.path.join(save_dir, 'cmd_args.txt')
    if os.path.exists(args_path):
        with open(args_path, 'r') as f:
            args.__dict__ = json.load(f)
        args.is_train = False
    encoder = Encoder(encoder_params[0], encoder_params[1]).cuda()
    decoder = Decoder(decoder_params[0], decoder_params[1],
                      args.frames_output).cuda()
    net = ED(encoder, decoder)
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

    if torch.cuda.device_count() > 1:
        net = nn.DataParallel(net)
    net.to(device)

    # Check the checkpoint file that is loaded, not only its run directory.
    if not os.path.exists(CHECKPOINT):
        print('there is no such checkpoint:', CHECKPOINT)
        sys.exit(1)
    print('==> loading existing model')
    model_info = torch.load(CHECKPOINT)
    net.load_state_dict(model_info['state_dict'])
    lossfunction = nn.MSELoss().cuda()

    # per-batch validation losses
    valid_losses = []
    preds = []
    ######################
    # validate the model #
    ######################
    with torch.no_grad():
        net.eval()
        t = tqdm(validLoader, leave=False, total=len(validLoader))
        for i, (idx, targetVar, inputVar, _, _) in enumerate(t):
            # cap on the number of evaluated batches
            if i == 3000:
                break
            inputs = inputVar.to(device)
            label = targetVar.to(device)
            pred = net(inputs)
            loss = lossfunction(pred, label)
            # Keep predictions on the CPU so they do not accumulate in GPU
            # memory and preds.pkl can be unpickled without a GPU.
            preds.append(pred.cpu())
            loss_aver = loss.item() / args.batch_size
            # record validation loss
            valid_losses.append(loss_aver)

    torch.cuda.empty_cache()
    # average validation loss over the evaluated batches
    valid_loss = np.average(valid_losses)

    print_msg = (f'valid_loss: {valid_loss:.6f}')
    print(print_msg)

    with open("preds.pkl", "wb") as fp:
        pickle.dump(preds, fp)
Beispiel #6
0
def train():
    '''
    Train the ED network, optionally resuming a previous run.

    If ``args.timestamp`` is ``"NA"``, a new run named after the current
    time is started.  Otherwise run ``args.timestamp`` is resumed and its
    saved command-line arguments are reloaded from
    ``./save_model/<TIMESTAMP>/cmd_args.txt``.  Model weights, optimizer
    state and epoch counter are restored from
    ``./save_model/<TIMESTAMP>/checkpoint.pth.tar`` when that file exists.
    For a new run, the arguments are written to ``cmd_args.txt`` before
    training starts, so the run can be resumed even if training is
    interrupted.

    Each epoch:
    - trains on ``trainLoader``;
    - validates on up to 3000 batches of ``validLoader``;
    - logs to TensorBoard under ``./runs/<TIMESTAMP>``;
    - prints the mean losses;
    - steps the LR scheduler and the early-stopping helper.
    Per-epoch mean losses are written to ``avg_train_losses.txt`` /
    ``avg_valid_losses.txt`` in the current working directory.
    '''
    restore = False
    if args.timestamp == "NA":
        TIMESTAMP = datetime.now().strftime("%b%d-%H%M%S")
        print('TIMESTAMP', TIMESTAMP)
    else:
        # resume the run named by args.timestamp
        restore = True
        TIMESTAMP = args.timestamp
    save_dir = './save_model/' + TIMESTAMP

    if restore:
        # restore args
        with open(os.path.join(save_dir, 'cmd_args.txt'), 'r') as f:
            args.__dict__ = json.load(f)

    encoder = Encoder(encoder_params[0], encoder_params[1]).cuda()
    decoder = Decoder(decoder_params[0], decoder_params[1],
                      args.frames_output).cuda()
    net = ED(encoder, decoder)
    run_dir = './runs/' + TIMESTAMP
    if not os.path.isdir(run_dir):
        os.makedirs(run_dir)
    tb = SummaryWriter(run_dir)
    # initialize the early_stopping object
    early_stopping = EarlyStopping(patience=30, verbose=True)
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

    if torch.cuda.device_count() > 1:
        net = nn.DataParallel(net)
    net.to(device)

    # Create the optimizer before a possible resume so that the restored
    # optimizer state is kept rather than replaced by a fresh instance.
    optimizer = optim.Adam(net.parameters(), lr=args.lr)
    checkpoint_path = os.path.join(save_dir, 'checkpoint.pth.tar')
    if os.path.exists(checkpoint_path):
        # load existing model from the file that was just checked
        print('==> loading existing model')
        model_info = torch.load(checkpoint_path)
        net.load_state_dict(model_info['state_dict'])
        optimizer.load_state_dict(model_info['optimizer'])
        cur_epoch = model_info['epoch'] + 1
    else:
        if not os.path.isdir(save_dir):
            os.makedirs(save_dir)
        cur_epoch = 0

    # Save the args now (save_dir exists at this point) instead of after
    # training, so an interrupted run can still be resumed via cmd_args.txt.
    if not restore:
        with open(os.path.join(save_dir, 'cmd_args.txt'), 'w+') as f:
            json.dump(args.__dict__, f, indent=2)

    lossfunction = nn.MSELoss().cuda()
    pla_lr_scheduler = lr_scheduler.ReduceLROnPlateau(optimizer,
                                                      factor=0.5,
                                                      patience=4,
                                                      verbose=True)

    # per-batch losses of the current epoch (cleared after every epoch)
    train_losses = []
    valid_losses = []
    # per-epoch mean losses, written to disk when training ends
    avg_train_losses = []
    avg_valid_losses = []
    for epoch in range(cur_epoch, args.epochs + 1):
        ###################
        # train the model #
        ###################
        t = tqdm(trainLoader, leave=False, total=len(trainLoader))
        for i, (idx, targetVar, inputVar, _, _) in enumerate(t):

            inputs = inputVar.to(device)  # B,S,C,H,W
            label = targetVar.to(device)  # B,S,C,H,W
            optimizer.zero_grad()
            net.train()
            pred = net(inputs)  # B,S,C,H,W
            loss = lossfunction(pred, label)
            loss_aver = loss.item() / args.batch_size
            train_losses.append(loss_aver)
            loss.backward()
            torch.nn.utils.clip_grad_value_(net.parameters(), clip_value=10.0)
            optimizer.step()
            t.set_postfix({
                'trainloss': '{:.6f}'.format(loss_aver),
                'epoch': '{:02d}'.format(epoch)
            })
        tb.add_scalar('TrainLoss', loss_aver, epoch)
        ######################
        # validate the model #
        ######################
        with torch.no_grad():
            net.eval()
            t = tqdm(validLoader, leave=False, total=len(validLoader))
            for i, (idx, targetVar, inputVar, _, _) in enumerate(t):
                # cap on the number of validation batches per epoch
                if i == 3000:
                    break
                inputs = inputVar.to(device)
                label = targetVar.to(device)
                pred = net(inputs)
                loss = lossfunction(pred, label)
                loss_aver = loss.item() / args.batch_size
                # record validation loss
                valid_losses.append(loss_aver)
                t.set_postfix({
                    'validloss': '{:.6f}'.format(loss_aver),
                    'epoch': '{:02d}'.format(epoch)
                })

        tb.add_scalar('ValidLoss', loss_aver, epoch)
        torch.cuda.empty_cache()
        # calculate average loss over an epoch
        train_loss = np.average(train_losses)
        valid_loss = np.average(valid_losses)
        avg_train_losses.append(train_loss)
        avg_valid_losses.append(valid_loss)

        epoch_len = len(str(args.epochs))

        print_msg = (f'[{epoch:>{epoch_len}}/{args.epochs:>{epoch_len}}] ' +
                     f'train_loss: {train_loss:.6f} ' +
                     f'valid_loss: {valid_loss:.6f}')

        print(print_msg)
        # clear lists to track next epoch
        train_losses = []
        valid_losses = []
        pla_lr_scheduler.step(valid_loss)  # lr_scheduler
        model_dict = {
            'epoch': epoch,
            'state_dict': net.state_dict(),
            'optimizer': optimizer.state_dict()
        }
        early_stopping(valid_loss.item(), model_dict, epoch, save_dir)
        if early_stopping.early_stop:
            print("Early stopping")
            break

    # Flush pending events and release the TensorBoard event file.
    tb.close()

    with open("avg_train_losses.txt", 'wt') as f:
        for i in avg_train_losses:
            print(i, file=f)

    with open("avg_valid_losses.txt", 'wt') as f:
        for i in avg_valid_losses:
            print(i, file=f)
Beispiel #7
0
def train(exp=None):
    """
    Train the ED network while tracking loss, PSNR and SSIM per output frame.

    The network is built from the module-level ``encoder_params`` /
    ``decoder_params``.  When ``args.continue_train`` is set and
    ``args.checkpoint`` exists, the following are restored from that
    checkpoint:
    - model weights;
    - optimizer state;
    - epoch counter;
    - the stored metric history.
    With ``args.checkdata``, every train/valid batch is first asserted to
    have the configured shape.

    Each epoch trains on ``trainLoader``, then evaluates ``validLoader``,
    accumulating PSNR/SSIM per predicted frame and uploading sample images
    every 500 validation batches.  It also steps the ReduceLROnPlateau
    scheduler and writes checkpoints to ``save_dir``:
    - ``checkpoint_<epoch>_<loss>.pth`` every ``args.save_every`` epochs;
    - ``bestpsnr_1.pth`` / ``bestssim_1.pth`` / ``bestvalidloss.pth`` when
      the respective metric of the epoch is the best so far;
    - ``checkpoint.pth`` on epochs where none of the above was written.
    ``EarlyStopping`` may end training early.  Per-epoch mean losses are
    written to ``avg_train_losses.txt`` / ``avg_valid_losses.txt`` in the
    current working directory at the end.

    Args:
        exp: optional experiment tracker exposing ``log_metric(name, value)``
            and also passed to ``upload_images`` (presumably a comet_ml
            ``Experiment`` -- TODO confirm); ``None`` disables metric logging.
    """
    encoder = Encoder(encoder_params[0], encoder_params[1]).cuda()
    decoder = Decoder(decoder_params[0], decoder_params[1]).cuda()
    net = ED(encoder, decoder)
    run_dir = "./runs/" + TIMESTAMP
    if not os.path.isdir(run_dir):
        os.makedirs(run_dir)
    # initialize the early_stopping object
    early_stopping = EarlyStopping(patience=20, verbose=True)
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

    if torch.cuda.device_count() > 1:
        net = nn.DataParallel(net)
    net.to(device)

    # Per-epoch history: mean losses and, per output frame j, mean PSNR/SSIM.
    avg_train_losses = []
    avg_valid_losses = []
    avg_psnrs = {j: [] for j in range(args.frames_output)}
    avg_ssims = {j: [] for j in range(args.frames_output)}

    # Create the optimizer before a possible resume so that the restored
    # optimizer state is kept rather than replaced by a fresh instance.
    optimizer = optim.Adam(net.parameters(), lr=args.lr)
    if os.path.exists(args.checkpoint) and args.continue_train:
        # load existing model
        print("==> loading existing model")
        model_info = torch.load(args.checkpoint)
        net.load_state_dict(model_info["state_dict"])
        optimizer.load_state_dict(model_info["optimizer"])
        cur_epoch = model_info["epoch"] + 1
        # Checkpoints written below also store the metric history.  Restoring
        # it keeps the best-* comparisons relative to the whole run, so the
        # first resumed epoch does not overwrite the best-* files with a
        # worse model.
        avg_train_losses = list(model_info.get("avg_train_losses", []))
        avg_valid_losses = list(model_info.get("avg_valid_losses", []))
        saved_psnrs = model_info.get("avg_psnrs", {})
        saved_ssims = model_info.get("avg_ssims", {})
        for j in range(args.frames_output):
            avg_psnrs[j] = list(saved_psnrs.get(j, []))
            avg_ssims[j] = list(saved_ssims.get(j, []))
    else:
        if not os.path.isdir(save_dir):
            os.makedirs(save_dir)
        cur_epoch = 0
    lossfunction = nn.MSELoss().cuda()
    pla_lr_scheduler = lr_scheduler.ReduceLROnPlateau(optimizer,
                                                      factor=0.5,
                                                      patience=4,
                                                      verbose=True)

    if args.checkdata:
        # Checking dataloader: every batch must have the configured shape.
        print("Checking Dataloader!")
        expected_target = torch.Size([
            args.batchsize, args.frames_output, 1, args.data_h, args.data_w
        ])
        expected_input = torch.Size([
            args.batchsize, args.frames_input, 1, args.data_h, args.data_w
        ])
        t = tqdm(trainLoader, leave=False, total=len(trainLoader))
        for i, (idx, targetVar, inputVar, _, _) in enumerate(t):
            assert targetVar.shape == expected_target
            assert inputVar.shape == expected_input
        print("TrainLoader checking is complete!")
        t = tqdm(validLoader, leave=False, total=len(validLoader))
        for i, (idx, targetVar, inputVar, _, _) in enumerate(t):
            assert targetVar.shape == expected_target
            assert inputVar.shape == expected_input
        print("ValidLoader checking is complete!")
    for epoch in range(cur_epoch, args.epochs + 1):
        # per-batch losses and per-frame PSNR/SSIM sums for this epoch
        train_losses = []
        valid_losses = []
        psnr_dict = {j: 0 for j in range(args.frames_output)}
        ssim_dict = {j: 0 for j in range(args.frames_output)}
        image_log = []
        if exp is not None:
            exp.log_metric("epoch", epoch)
        ###################
        # train the model #
        ###################
        t = tqdm(trainLoader, leave=False, total=len(trainLoader))
        for i, (idx, targetVar, inputVar, _, _) in enumerate(t):
            inputs = inputVar.to(device)  # B,S,C,H,W
            label = targetVar.to(device)  # B,S,C,H,W
            optimizer.zero_grad()
            net.train()
            pred = net(inputs)  # B,S,C,H,W
            loss = lossfunction(pred, label)
            loss_aver = loss.item() / args.batchsize
            train_losses.append(loss_aver)
            loss.backward()
            torch.nn.utils.clip_grad_value_(net.parameters(), clip_value=10.0)
            optimizer.step()
            t.set_postfix({
                "trainloss": "{:.6f}".format(loss_aver),
                "epoch": "{:02d}".format(epoch),
            })
        ######################
        # validate the model #
        ######################
        with torch.no_grad():
            net.eval()
            t = tqdm(validLoader, leave=False, total=len(validLoader))
            for i, (idx, targetVar, inputVar, _, _) in enumerate(t):
                inputs = inputVar.to(device)
                label = targetVar.to(device)
                pred = net(inputs)
                loss = lossfunction(pred, label)
                loss_aver = loss.item() / args.batchsize
                # record validation loss
                valid_losses.append(loss_aver)

                for j in range(args.frames_output):
                    psnr_dict[j] += psnr(pred[:, j], label[:, j])
                    ssim_dict[j] += ssim(pred[:, j], label[:, j])
                t.set_postfix({
                    "validloss": "{:.6f}".format(loss_aver),
                    "epoch": "{:02d}".format(epoch),
                })
                if i % 500 == 499:
                    # NOTE(review): image_log is not cleared after an upload,
                    # so each upload re-sends every image collected so far in
                    # this epoch -- confirm this is intended.
                    for k in range(args.frames_output):
                        image_log.append(label[0, k].unsqueeze(0).repeat(
                            1, 3, 1, 1))
                        image_log.append(pred[0, k].unsqueeze(0).repeat(
                            1, 3, 1, 1))
                    upload_images(
                        image_log,
                        epoch,
                        exp=exp,
                        im_per_row=2,
                        rows_per_log=int(len(image_log) / 2),
                    )
        torch.cuda.empty_cache()
        # calculate average loss over an epoch
        train_loss = np.average(train_losses)
        valid_loss = np.average(valid_losses)
        avg_train_losses.append(train_loss)
        avg_valid_losses.append(valid_loss)
        # Average PSNR/SSIM over the number of validation batches.  The last
        # enumerate index is one less than that count (0 for a single batch),
        # so it is not a valid divisor.
        n_valid = max(len(valid_losses), 1)
        for j in range(args.frames_output):
            avg_psnrs[j].append(psnr_dict[j] / n_valid)
            avg_ssims[j].append(ssim_dict[j] / n_valid)
        psnr_1 = avg_psnrs[0][-1]
        ssim_1 = avg_ssims[0][-1]

        if exp is not None:
            exp.log_metric("TrainLoss", train_loss)
            exp.log_metric("ValidLoss", valid_loss)
            exp.log_metric("PSNR_1", psnr_1)
            exp.log_metric("SSIM_1", ssim_1)
        pla_lr_scheduler.step(valid_loss)  # lr_scheduler
        model_dict = {
            "epoch": epoch,
            "state_dict": net.state_dict(),
            "optimizer": optimizer.state_dict(),
            "avg_psnrs": avg_psnrs,
            "avg_ssims": avg_ssims,
            "avg_valid_losses": avg_valid_losses,
            "avg_train_losses": avg_train_losses,
        }
        saved = False
        if epoch % args.save_every == 0:
            ckpt_name = "checkpoint_{}_{:.6f}.pth".format(
                epoch, valid_loss.item())
            torch.save(model_dict, save_dir + "/" + ckpt_name)
            print("Saved" + ckpt_name)
            saved = True
        # Each best-metric file is refreshed on its own, so none of them goes
        # stale when several metrics improve in the same epoch.
        if psnr_1 == max(avg_psnrs[0]):
            torch.save(model_dict, save_dir + "/" + "bestpsnr_1.pth")
            print("Best psnr found and saved")
            saved = True
        if ssim_1 == max(avg_ssims[0]):
            torch.save(model_dict, save_dir + "/" + "bestssim_1.pth")
            print("Best ssim found and saved")
            saved = True
        if avg_valid_losses[-1] == min(avg_valid_losses):
            torch.save(model_dict, save_dir + "/" + "bestvalidloss.pth")
            print("Best validloss found and saved")
            saved = True
        # Rolling "latest" checkpoint for epochs where nothing above was saved.
        if not saved:
            torch.save(model_dict, save_dir + "/" + "checkpoint.pth")
            print("The latest normal checkpoint saved")
        early_stopping(valid_loss.item(), model_dict, epoch, save_dir)
        if early_stopping.early_stop:
            print("Early stopping")
            break

    with open("avg_train_losses.txt", "wt") as f:
        for i in avg_train_losses:
            print(i, file=f)

    with open("avg_valid_losses.txt", "wt") as f:
        for i in avg_valid_losses:
            print(i, file=f)
Beispiel #8
0
def flattened_sequence_loss(pred, label, lossfunction):
    """Flatten a predicted / target mask sequence and compute the loss.

    ``pred`` is unpacked as a 5-D tensor ``(dim0, dim1, C, H, W)`` of logits and
    ``label`` as a 5-D tensor ``(dim0, dim1, C_label, H, W)`` of class ids; the
    training code annotates both as ``B, S, C, H, W`` (TODO confirm).  Leading
    dimensions are merged so ``lossfunction`` receives ``(N, C, H, W)`` logits
    and ``(N, H, W)`` long targets, the layout ``nn.CrossEntropyLoss`` expects.

    Returns:
        ``(loss, loss_aver)``: ``loss`` is the tensor to back-propagate;
        ``loss_aver`` is the float that ``train`` logs and averages, computed
        as ``loss.item() / (N_label * label.size(1))``.
    """
    _, _, channels, height, width = pred.size()
    pred = pred.reshape(-1, channels, height, width)
    label_dim1 = label.size(1)
    height, width = label.shape[-2:]
    label = label.reshape(-1, height, width).to(device=pred.device, dtype=torch.long)
    loss = lossfunction(pred, label)
    # NOTE(review): with the loss module's default 'mean' reduction (as built in
    # train) the loss is already a per-pixel average; this extra division is
    # kept unchanged so logged values, the LR scheduler and early stopping see
    # the same numbers as before.
    loss_aver = loss.item() / (label.shape[0] * label_dim1)
    return loss, loss_aver


def train():
    '''
    Main function to run the training.

    Builds the encoder-decoder network ``ED(encoder, decoder)``, optionally
    resumes from ``<save_dir>/checkpoint.pth.tar`` (weights, optimizer state,
    epoch), then trains with a class-weighted cross-entropy loss. Each epoch:

    * trains on ``trainLoader``. All frames except the last 3 are the input;
      the last 3 mask frames are the target.
    * validates on at most 300 batches of ``validLoader`` using the same
      input/target split.
    * logs losses and visualizations to TensorBoard under
      ``./runs/<TIMESTAMP><args.mname>``.
    * steps the ReduceLROnPlateau scheduler on the validation loss.
    * saves ``best_train_checkpoint.pth.tar`` when the train loss improves.
    * hands the model dict to ``EarlyStopping``, which may stop training.

    Per-epoch average losses are finally written to ``avg_train_losses.txt``
    and ``avg_valid_losses.txt`` in the current working directory.
    '''
    encoder = Encoder(encoder_params[0], encoder_params[1]).cuda()
    decoder = Decoder(decoder_params[0], decoder_params[1]).cuda()
    net = ED(encoder, decoder)
    run_dir = './runs/' + TIMESTAMP + args.mname
    os.makedirs(run_dir, exist_ok=True)
    tb = SummaryWriter(run_dir)
    # initialize the early_stopping object
    early_stopping = EarlyStopping(patience=200, verbose=True)
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

    if torch.cuda.device_count() > 1:
        net = nn.DataParallel(net)
    net.to(device)

    # Create the optimizer once, before a possible resume, so a restored
    # optimizer state (including its learning rate) is not discarded afterwards.
    optimizer = optim.Adam(net.parameters(), lr=args.lr)
    checkpoint_path = os.path.join(save_dir, 'checkpoint.pth.tar')
    if os.path.exists(checkpoint_path):
        # load existing model: weights, optimizer state and epoch counter
        print('==> loading existing model')
        model_info = torch.load(checkpoint_path, map_location=device)
        net.load_state_dict(model_info['state_dict'])
        optimizer.load_state_dict(model_info['optimizer'])
        cur_epoch = model_info['epoch'] + 1
    else:
        os.makedirs(save_dir, exist_ok=True)
        cur_epoch = 0

    # Class 1 is weighted 15x relative to class 0 (presumably to counter a rare
    # foreground class -- confirm).
    class_weights = torch.FloatTensor([1.0, 15.0]).cuda()
    lossfunction = nn.CrossEntropyLoss(weight=class_weights).cuda()
    pla_lr_scheduler = lr_scheduler.ReduceLROnPlateau(optimizer,
                                                      factor=0.5,
                                                      patience=5,
                                                      verbose=True)

    # to track the training loss as the model trains
    train_losses = []
    # to track the validation loss as the model trains
    valid_losses = []
    # to track the average training loss per epoch as the model trains
    avg_train_losses = []
    # to track the average validation loss per epoch as the model trains
    avg_valid_losses = []
    min_train_loss = np.inf
    for epoch in range(cur_epoch, args.epochs + 1):
        print(time.strftime("now time: %Y%m%d_%H:%M", time.localtime(time.time())))
        ###################
        # train the model #
        ###################
        # training mode: only affects dropout and batch norm layers
        net.train()
        # tqdm progress bar
        t = tqdm(trainLoader, total=len(trainLoader))
        for i, (seq_len, scan_seq, _, mask_seq, _) in enumerate(t):
            # Sequence length varies: all frames but the last 3 (at least 2)
            # are the input; the last 3 mask frames are always the target.
            inputs = torch.cat((scan_seq, mask_seq.float()), dim=2).to(device)[:, :-3, ...]   # B,S,C,H,W
            label = mask_seq.to(device)[:, (seq_len - 3):, ...]     # B,S,C,H,W
            optimizer.zero_grad()
            pred = net(inputs)  # B,S,C,H,W

            # draw a visualization in TensorBoard every 100 batches
            if i % 100 == 0:
                grid_ri_lab, grid_pred = get_visualization_example(scan_seq.to(device), mask_seq.to(device), pred, device)
                tb.add_image('visualization/train/rangeImage_gtMask', grid_ri_lab, global_step=epoch)
                tb.add_image('visualization/train/prediction', grid_pred, global_step=epoch)

            loss, loss_aver = flattened_sequence_loss(pred, label, lossfunction)
            train_losses.append(loss_aver)
            loss.backward()
            # Guard against exploding gradients: clip every gradient element to
            # [-clip_value, clip_value].
            torch.nn.utils.clip_grad_value_(net.parameters(), clip_value=30.0)
            optimizer.step()
            t.set_postfix({
                'trainloss': '{:.6f}'.format(loss_aver),
                'epoch': '{:02d}'.format(epoch)
            })
        tb.add_scalar('TrainLoss', np.average(train_losses), epoch)
        ######################
        # validate the model #
        ######################
        with torch.no_grad():
            # eval mode: only affects dropout and batch norm layers
            net.eval()
            # tqdm progress bar
            t = tqdm(validLoader, total=len(validLoader))
            for i, (seq_len, scan_seq, _, mask_seq, _) in enumerate(t):
                if i == 300:    # cap the number of validation batches
                    break
                # Same split as training: the last 3 frames are the target and
                # must not be fed to the network, otherwise the ground-truth
                # target masks leak into the input.
                inputs = torch.cat((scan_seq, mask_seq.float()), dim=2).to(device)[:, :-3, ...]   # B,S,C,H,W
                label = mask_seq.to(device)[:, (seq_len - 3):, ...]     # B,S,C,H,W
                pred = net(inputs)

                # draw a visualization in TensorBoard every 100 batches
                if i % 100 == 0:
                    grid_ri_lab, grid_pred = get_visualization_example(scan_seq.to(device), mask_seq.to(device), pred, device)
                    tb.add_image('visualization/valid/rangeImage_gtMask', grid_ri_lab, global_step=epoch)
                    tb.add_image('visualization/valid/prediction', grid_pred, global_step=epoch)

                loss, loss_aver = flattened_sequence_loss(pred, label, lossfunction)
                # record validation loss
                valid_losses.append(loss_aver)
                t.set_postfix({
                    'validloss': '{:.6f}'.format(loss_aver),
                    'epoch': '{:02d}'.format(epoch)
                })

        tb.add_scalar('ValidLoss', np.average(valid_losses), epoch)
        torch.cuda.empty_cache()
        # calculate average loss over an epoch and print training/validation statistics
        train_loss = np.average(train_losses)
        valid_loss = np.average(valid_losses)
        avg_train_losses.append(train_loss)
        avg_valid_losses.append(valid_loss)

        epoch_len = len(str(args.epochs))

        print_msg = (f'[{epoch:>{epoch_len}}/{args.epochs:>{epoch_len}}] ' +
                     f'train_loss: {train_loss:.6f} ' +
                     f'valid_loss: {valid_loss:.6f}')

        print(print_msg)
        # clear lists to track next epoch
        train_losses = []
        valid_losses = []
        pla_lr_scheduler.step(valid_loss)  # lr_scheduler
        model_dict = {
            'epoch': epoch,
            'state_dict': net.state_dict(),
            'optimizer': optimizer.state_dict()
        }
        # save the model with the lowest train loss so far
        if train_loss < min_train_loss:
            torch.save(model_dict, save_dir + "/" + "best_train_checkpoint.pth.tar")
            min_train_loss = train_loss
        # EarlyStopping receives the model dict and save_dir (presumably it
        # saves the lowest-valid-loss model -- confirm in EarlyStopping)
        early_stopping(valid_loss.item(), model_dict, epoch, save_dir)
        if early_stopping.early_stop:
            print("Early stopping")
            break
    # end for

    # flush and release the TensorBoard event file
    tb.close()

    with open("avg_train_losses.txt", 'wt', encoding='utf-8') as f:
        for i in avg_train_losses:
            print(i, file=f)

    with open("avg_valid_losses.txt", 'wt', encoding='utf-8') as f:
        for i in avg_valid_losses:
            print(i, file=f)