def train(working_dir, grid_size, learning_rate, batch_size, num_walks,
          model_type, fn):
    train_props, val_props, test_props = get_props(working_dir,
                                                   dtype=np.float32)
    means_stds = np.loadtxt(working_dir + "/means_stds.csv",
                            dtype=np.float32,
                            delimiter=',')

    # filter out redundant qm8 properties
    if train_props.shape[1] == 16:
        filtered_labels = list(range(0, 8)) + list(range(12, 16))
        train_props = train_props[:, filtered_labels]
        val_props = val_props[:, filtered_labels]
        test_props = test_props[:, filtered_labels]

        means_stds = means_stds[:, filtered_labels]
    if model_type == "resnet18":
        model = ResNet(BasicBlock, [2, 2, 2, 2],
                       grid_size,
                       "regression",
                       feat_nums,
                       e_sizes,
                       num_classes=train_props.shape[1])
    elif model_type == "resnet34":
        model = ResNet(BasicBlock, [3, 4, 6, 3],
                       grid_size,
                       "regression",
                       feat_nums,
                       e_sizes,
                       num_classes=train_props.shape[1])
    elif model_type == "resnet50":
        model = ResNet(Bottleneck, [3, 4, 6, 3],
                       grid_size,
                       "regression",
                       feat_nums,
                       e_sizes,
                       num_classes=train_props.shape[1])
    elif model_type == "densenet121":
        model = densenet121(grid_size,
                            "regression",
                            feat_nums,
                            e_sizes,
                            num_classes=train_props.shape[1])
    elif model_type == "densenet161":
        model = densenet161(grid_size,
                            "regression",
                            feat_nums,
                            e_sizes,
                            num_classes=train_props.shape[1])
    elif model_type == "densenet169":
        model = densenet169(grid_size,
                            "regression",
                            feat_nums,
                            e_sizes,
                            num_classes=train_props.shape[1])
    elif model_type == "densenet201":
        model = densenet201(grid_size,
                            "regression",
                            feat_nums,
                            e_sizes,
                            num_classes=train_props.shape[1])
    else:
        print("specify a valid model")
        return
    model.float()
    model.cuda()
    loss_function_train = nn.MSELoss(reduction='none')
    loss_function_val = nn.L1Loss(reduction='none')
    optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)

    # if model_type[0] == "r":
    # 	batch_size = 128
    # 	optimizer = torch.optim.SGD(model.parameters(), lr=0.1,
    # 					   momentum=0.9, weight_decay=5e-4, nesterov=True)
    # elif model_type[0] == "d":
    # 	batch_size = 512
    # 	optimizer = torch.optim.SGD(model.parameters(), lr=0.1,
    # 					   momentum=0.9, weight_decay=1e-4, nesterov=True)
    # else:
    # 	print("specify a vlid model")
    # 	return

    stds = means_stds[1, :]
    tl_list = []
    vl_list = []

    log_file = open(fn + "txt", "w")
    log_file.write("start")
    log_file.flush()

    for file_num in range(num_loads):
        if file_num % 20 == 0:
            model_file = open("../../scratch/" + fn + ".pkl", "wb")
            pickle.dump(model, model_file)
            model_file.close()

        log_file.write("load: " + str(file_num))
        print("load: " + str(file_num))
        # Get new random walks
        if file_num == 0:
            t = time.time()
            train_loader, val_loader, test_loader = get_loaders(
                working_dir, file_num, grid_size, batch_size, train_props,
                val_props=val_props, test_props=test_props)
            print("load time")
            print(time.time() - t)
        else:
            file_num = random.randint(0, num_walks - 1)
            t = time.time()
            train_loader, _, _ = get_loaders(
                working_dir, file_num, grid_size, batch_size, train_props)
            print("load time")
            print(time.time() - t)
        # Train on set of random walks, can do multiple epochs if desired
        for epoch in range(epochs_per_load):
            model.train()
            t = time.time()
            train_loss_list = []
            train_mae_loss_list = []
            for i, (walks_int, walks_float, props) in enumerate(train_loader):
                walks_int = walks_int.cuda()
                walks_int = walks_int.long()
                walks_float = walks_float.cuda()
                walks_float = walks_float.float()
                props = props.cuda()
                outputs = model(walks_int, walks_float)
                # Individual losses for each item
                loss_mae = torch.mean(loss_function_val(props, outputs), 0)
                train_mae_loss_list.append(loss_mae.cpu().detach().numpy())
                loss = torch.mean(loss_function_train(props, outputs), 0)
                train_loss_list.append(loss.cpu().detach().numpy())
                # Loss converted to single value for backpropagation
                loss = torch.sum(loss)
                optimizer.zero_grad()
                loss.backward()
                optimizer.step()
            model.eval()
            val_loss_list = []
            with torch.no_grad():
                for i, (walks_int, walks_float,
                        props) in enumerate(val_loader):
                    walks_int = walks_int.cuda()
                    walks_int = walks_int.long()
                    walks_float = walks_float.cuda()
                    walks_float = walks_float.float()
                    props = props.cuda()
                    outputs = model(walks_int, walks_float)
                    # Individual losses for each item
                    loss = loss_function_val(props, outputs)
                    val_loss_list.append(loss.cpu().detach().numpy())
            # ith row of this array is the losses for each label in batch i
            train_loss_arr = np.array(train_loss_list)
            train_mae_arr = np.array(train_mae_loss_list)
            log_file.write("training mse loss\n")
            log_file.write(str(np.mean(train_loss_arr)) + "\n")
            log_file.write("training mae loss\n")
            log_file.write(str(np.mean(train_mae_arr)) + "\n")
            print("training mse loss")
            print(str(np.mean(train_loss_arr)))
            print("training mae loss")
            print(str(np.mean(train_mae_arr)))
            val_loss_arr = np.concatenate(val_loss_list, 0)
            val_loss = np.mean(val_loss_arr, 0)
            log_file.write("val loss\n")
            log_file.write(str(np.mean(val_loss_arr)) + "\n")
            print("val loss")
            print(str(np.mean(val_loss_arr)))
            # Unnormalized losses (scaled back by stds below) are for comparison to papers
            tnl = np.mean(train_mae_arr, 0)
            log_file.write("train normalized losses\n")
            log_file.write(" ".join(list(map(str, tnl))) + "\n")
            print("train normalized losses")
            print(" ".join(list(map(str, tnl))))
            log_file.write("val normalized losses\n")
            log_file.write(" ".join(list(map(str, val_loss))) + "\n")
            print("val normalized losses")
            print(" ".join(list(map(str, val_loss))))
            tunl = stds * tnl
            log_file.write("train unnormalized losses\n")
            log_file.write(" ".join(list(map(str, tunl))) + "\n")
            print("train unnormalized losses")
            print(" ".join(list(map(str, tunl))))
            vunl = stds * val_loss
            log_file.write("val unnormalized losses\n")
            log_file.write(" ".join(list(map(str, vunl))) + "\n")
            log_file.write("\n")
            print("val unnormalized losses")
            print(" ".join(list(map(str, vunl))))
            print("\n")
            print("time")
            print(time.time() - t)
        file_num += 1
        log_file.flush()
    log_file.close()
    return model
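
The unnormalized losses logged above are just the per-property MAE scaled back by the standard deviations stored in means_stds.csv. A minimal sketch of that relationship, assuming the targets were standardized to zero mean and unit variance before training (the numbers below are made up):

import numpy as np

def unnormalize_mae(normalized_mae, stds):
    # MAE is scale-equivariant: mean(|s*x - s*y|) = s * mean(|x - y|) for s > 0,
    # so multiplying each property's normalized MAE by that property's std
    # recovers the error in the original units.
    return stds * normalized_mae

# Hypothetical values, purely illustrative.
stds = np.array([0.02, 1.5, 0.3], dtype=np.float32)
normalized_mae = np.array([0.10, 0.08, 0.12], dtype=np.float32)
print(unnormalize_mae(normalized_mae, stds))
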
Example #2
def train(name):
    record = pd.DataFrame(data=np.zeros((1, 4), dtype=float),
                          columns=['precision', 'accuracy', 'recall', 'F1'])
    for _ in range(opt.runs):
        seed = random.randint(1, 10000)
        print("Random Seed: ", seed)
        torch.manual_seed(seed)

        # mkdirs for checkpoints output
        os.makedirs(opt.checkpoints_folder, exist_ok=True)
        os.makedirs('%s/%s' % (opt.checkpoints_folder, name), exist_ok=True)
        os.makedirs('report_metrics', exist_ok=True)

        root_dir = 'report_metrics/%s_aug_%s_IMBA/%s' % (
            opt.model, str(opt.n_group), name)
        os.makedirs(root_dir, exist_ok=True)

        # Load the dataset
        path = 'UCRArchive_2018/' + name + '/' + name + '_TRAIN.tsv'
        train_set, n_class = load_ucr(path)

        print('Balanced data augmentation enabled!')
        stratified_train_set = stratify_by_label(train_set)
        data_aug_set = data_aug_by_dft(stratified_train_set, opt.n_group)
        total_set = np.concatenate((train_set, data_aug_set))
        print('Shape of total set', total_set.shape)
        dataset = UcrDataset(total_set, channel_last=opt.channel_last)

        batch_size = int(min(len(dataset) / 10, 16))
        dataloader = UCR_dataloader(dataset, batch_size)

        # Common behavior
        seq_len = dataset.get_seq_len()  # sequence length of the dataset
        # Create the classifier, loss function, and optimizer
        if opt.model == 'r':
            net = ResNet(n_in=seq_len, n_classes=n_class).to(device)
        elif opt.model == 'f':
            net = ConvNet(n_in=seq_len, n_classes=n_class).to(device)
        else:
            raise ValueError('unknown model type: %s' % opt.model)
        criterion = nn.CrossEntropyLoss().to(device)
        optimizer = optim.Adam(net.parameters(), lr=opt.lr)
        scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer,
                                                         mode='min',
                                                         factor=0.5,
                                                         patience=50,
                                                         min_lr=0.0001)

        min_loss = 10000
        print('############# Start to Train ###############')
        net.train()
        for epoch in range(opt.epochs):
            for i, (data, label) in enumerate(dataloader):
                data = data.float()
                data = data.to(device)
                label = label.long()
                label = label.to(device)
                optimizer.zero_grad()
                output = net(data)
                loss = criterion(output, label.view(label.size(0)))
                loss.backward()
                optimizer.step()
                scheduler.step(loss)
                # print('[%d/%d][%d/%d] Loss: %.8f ' % (epoch, opt.epochs, i + 1, len(dataloader), loss.item()))
            if loss.item() < min_loss:
                min_loss = loss.item()
                # End of the epoch,save model
                print('MinLoss: %.10f Saving the best epoch model.....' %
                      min_loss)
                torch.save(
                    net, '%s/%s/%s_%s_best_IMBA.pth' %
                    (opt.checkpoints_folder, name, opt.model, str(
                        opt.n_group)))
        net_path = '%s/%s/%s_%s_best_IMBA.pth' % (opt.checkpoints_folder, name,
                                                  opt.model, str(opt.n_group))
        one_record = eval_accuracy(net_path, name)
        print('The minimum loss is %.8f' % min_loss)
        # DataFrame.append was removed in pandas 2.x; assumes a dict/Series of metrics
        record = pd.concat([record, pd.DataFrame([one_record])],
                           ignore_index=True)
    record = record.drop(index=[0])
    # compute both before appending so the 'std' row ignores the 'mean' row
    run_mean, run_std = record.mean(), record.std()
    record.loc['mean'] = run_mean
    record.loc['std'] = run_std
    record.to_csv(root_dir + '/metrics.csv')
    # all_reprot_metrics.loc[name, 'acc_mean'] = record.at['mean', 'accuracy']
    # all_reprot_metrics.loc[name, 'acc_std'] = record.at['std', 'accuracy']
    # all_reprot_metrics.loc[name, 'F1_mean'] = record.at['mean', 'F1']
    # all_reprot_metrics.loc[name, 'F1_std'] = record.at['std', 'F1']

    print('\n')
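
The metrics bookkeeping above collects one row per run and then appends mean and std summary rows. A self-contained sketch of that pattern, assuming eval_accuracy returns a dict with the four metric columns (all numbers invented):

import numpy as np
import pandas as pd

cols = ['precision', 'accuracy', 'recall', 'F1']
record = pd.DataFrame(np.zeros((1, 4)), columns=cols)  # placeholder first row

# Stand-ins for what eval_accuracy() might return per run (made-up numbers).
for run_metrics in ({'precision': 0.90, 'accuracy': 0.91, 'recall': 0.88, 'F1': 0.89},
                    {'precision': 0.92, 'accuracy': 0.93, 'recall': 0.90, 'F1': 0.91}):
    record = pd.concat([record, pd.DataFrame([run_metrics])], ignore_index=True)

record = record.drop(index=[0])                   # drop the placeholder row
run_mean, run_std = record.mean(), record.std()   # summarize before appending
record.loc['mean'] = run_mean
record.loc['std'] = run_std
print(record)
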
Example #3
def main():
    if not sys.warnoptions:
        warnings.simplefilter("ignore")

    # --- hyper parameters --- #
    BATCH_SIZE = 256
    LR = 1e-3
    WEIGHT_DECAY = 1e-4
    N_layer = 18
    device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')

    # --- data process --- #
    # info
    src_path = './data/'
    target_path = './saved/ResNet18/'
    model_path = target_path + 'pkls/'
    pred_path = target_path + 'preds/'

    if not os.path.exists(model_path):
        os.makedirs(model_path)
    if not os.path.exists(pred_path):
        os.makedirs(pred_path)

    # evaluation: num of classify labels & image size
    # output testing id csv
    label2num_dict, num2label_dict = data_evaluation(src_path)

    # load
    train_data = dataLoader(src_path, 'train', label2num_dict)
    train_len = len(train_data)
    test_data = dataLoader(src_path, 'test')

    train_loader = Data.DataLoader(
        dataset=train_data,
        batch_size=BATCH_SIZE,
        shuffle=True,
        num_workers=12,
    )
    test_loader = Data.DataLoader(
        dataset=test_data,
        batch_size=BATCH_SIZE,
        shuffle=False,
        num_workers=12,
    )

    # --- model training --- #
    # fp: for storing data
    fp_train_acc = open(target_path + 'train_acc.txt', 'w')
    fp_time = open(target_path + 'time.txt', 'w')

    # train
    highest_acc, train_acc_seq = 0, []
    loss_funct = nn.CrossEntropyLoss()
    net = ResNet(N_layer).to(device)
    optimizer = torch.optim.Adam(net.parameters(),
                                 lr=LR,
                                 weight_decay=WEIGHT_DECAY)
    print(net)

    # itertools.count(1) never stops, so training runs until interrupted
    for epoch_i in count(1):
        right_count = 0

        # print('\nTraining epoch {}...'.format(epoch_i))
        # for batch_x, batch_y in tqdm(train_loader):
        for batch_x, batch_y in train_loader:
            batch_x = batch_x.to(device)
            batch_y = batch_y.to(device)

            # clear gradient
            optimizer.zero_grad()

            # forward & backward
            output = net.forward(batch_x.float())
            highest_out = torch.max(output, 1)[1]
            right_count += sum(batch_y == highest_out).item()

            loss = loss_funct(output, batch_y)
            loss.backward()

            # update parameters
            optimizer.step()

        # calculate accuracy
        train_acc = right_count / train_len
        train_acc_seq.append(train_acc * 100)

        if train_acc > highest_acc:
            highest_acc = train_acc

        # save model
        torch.save(
            net.state_dict(),
            '{}{}_{}_{}.pkl'.format(model_path,
                                    target_path.split('/')[2],
                                    round(train_acc * 1000), epoch_i))

        # write data
        fp_train_acc.write(str(train_acc * 100) + '\n')
        fp_time.write(
            str(time.strftime("%Y-%m-%d %H:%M:%S", time.localtime())) + '\n')
        print('\n{} Epoch {}, Training accuracy: {}'.format(
            time.strftime("%Y-%m-%d %H:%M:%S", time.localtime()), epoch_i,
            train_acc))

        # test
        net.eval()
        test_df = pd.read_csv(src_path + 'testing_data/testing_labels.csv')
        with torch.no_grad():
            for i, (batch_x, _) in enumerate(test_loader):
                batch_x = batch_x.to(device)
                output = net.forward(batch_x.float())
                highest_out = torch.max(output, 1)[1].cpu()
                labels = [
                    num2label_dict[out_j.item()] for out_j in highest_out
                ]
                # assign through .loc to avoid pandas chained-assignment issues
                test_df.loc[test_df.index[i * BATCH_SIZE:(i + 1) * BATCH_SIZE],
                            'label'] = labels
        test_df.to_csv('{}{}_{}_{}.csv'.format(pred_path,
                                               target_path.split('/')[2],
                                               round(train_acc * 1000),
                                               epoch_i),
                       index=False)
        net.train()

        lr_decay(optimizer)

    fp_train_acc.close()
    fp_time.close()
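
lr_decay is called once per epoch but not defined in this snippet. A plausible sketch, assuming it simply scales every parameter group's learning rate by a constant factor (the factor name and default are guesses):

def lr_decay(optimizer, factor=0.95):
    # Hypothetical helper: shrink each parameter group's learning rate.
    for param_group in optimizer.param_groups:
        param_group['lr'] *= factor
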
Example #4
        else:
            from VGG_predict import make_predictions
            from models.VGG_16 import VGG_16_test, VGG_19_test
            
            if args.version == '16':
                make_predictions(args.img_dim, VGG_16_test)
            elif args.version == '19':
                make_predictions(args.img_dim, VGG_19_test)
            else:
                sys.exit('cannot find model you have specified')
    
    elif args.model == 'ResNet':
        import models.ResNet as ResNet

        train_data, train_labels, valid_data, valid_labels, test_data, test_ids = preprocess.get_roof_data(augmented=True, shape=(64, 64))

        if args.train:

            ResNet.train(train_data, train_labels, valid_data, valid_labels, dropout=0.62, num_blocks=3, lr=0.007, weight_decay=0.004)

        else:
            model_vargs = dict(dropout=0.62, num_blocks=3)
            fn = 'results/best.model'
            valid_predictions = ResNet.predict(fn, model_vargs, valid_data)
            test_predictions = ResNet.predict(fn, model_vargs, test_data)
            make_prediction_file.make_prediction_file(test_ids, test_predictions, 'Resnet805_64_64', valid_labels=valid_labels, valid_predictions=valid_predictions)
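
This fragment starts inside a larger dispatch on args that is not shown. A hedged sketch of the argument parser it implies, with flag names taken from the attribute accesses above and defaults that are illustrative guesses:

import argparse

def parse_args():
    # Inferred from args.model / args.train / args.version / args.img_dim;
    # names match the attributes used above, defaults are guesses.
    parser = argparse.ArgumentParser()
    parser.add_argument('--model', choices=['VGG', 'ResNet'], default='ResNet')
    parser.add_argument('--train', action='store_true')
    parser.add_argument('--version', choices=['16', '19'], default='16')
    parser.add_argument('--img_dim', type=int, default=64)
    return parser.parse_args()
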

def train(working_dir, grid_size, learning_rate, batch_size, num_cores):
    process = psutil.Process(os.getpid())
    print(process.memory_info().rss / 1024 / 1024 / 1024)
    train_feat_dict = get_feat_dict(working_dir + "/train_smiles.csv")
    val_feat_dict = get_feat_dict(working_dir + "/val_smiles.csv")
    test_feat_dict = get_feat_dict(working_dir + "/test_smiles.csv")
    # These feature dicts take roughly 0.08 GB
    process = psutil.Process(os.getpid())
    print("pre model")
    print(process.memory_info().rss / 1024 / 1024 / 1024)

    torch.set_default_dtype(torch.float64)
    train_props, val_props, test_props = get_props(working_dir, dtype=int)
    print("pre model post props")
    print(process.memory_info().rss / 1024 / 1024 / 1024)
    model = ResNet(BasicBlock, [2, 2, 2, 2],
                   grid_size,
                   "classification",
                   feat_nums,
                   e_sizes,
                   num_classes=train_props.shape[1])
    model.float()
    model.cuda()
    print("model params")
    pytorch_total_params = sum(p.numel() for p in model.parameters())
    print(pytorch_total_params)
    model.cpu()
    print("model")
    print(process.memory_info().rss / 1024 / 1024 / 1024)
    loss_function = masked_cross_entropy
    optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)

    tl_list = []
    vl_list = []
    tmra_list = []
    vmra_list = []

    for file_num in range(num_loads):
        # Get new random walks
        if file_num == 0:
            print("before get_loaders")
            process = psutil.Process(os.getpid())
            print(process.memory_info().rss / 1024 / 1024 / 1024)
            train_loader, val_loader, test_loader = get_loaders(
                num_cores, working_dir, file_num, grid_size, batch_size,
                train_props, train_feat_dict,
                val_props=val_props, val_feat_dict=val_feat_dict,
                test_props=test_props, test_feat_dict=test_feat_dict)
        else:
            print("before get_loaders 2")
            process = psutil.Process(os.getpid())
            print(process.memory_info().rss / 1024 / 1024 / 1024)
            train_loader, _, _ = get_loaders(
                num_cores, working_dir, file_num, grid_size, batch_size,
                train_props, train_feat_dict)
        # Train on a single set of random walks, can do multiple epochs if desired
        for epoch in range(epochs_per_load):
            model.train()
            model.cuda()
            t = time.time()
            train_loss_list = []
            props_list = []
            outputs_list = []
            # change
            for i, (walks_int, walks_float, props) in enumerate(train_loader):
                walks_int = walks_int.cuda()
                walks_int = walks_int.long()
                walks_float = walks_float.cuda()
                walks_float = walks_float.float()
                props = props.cuda()
                props = props.long()
                props_list.append(props)
                outputs = model(walks_int, walks_float)
                outputs_list.append(outputs)
                loss = loss_function(props, outputs)
                train_loss_list.append(loss.item())
                optimizer.zero_grad()
                loss.backward()
                optimizer.step()
            props = torch.cat(props_list, 0)
            props = props.cpu().numpy()
            outputs = torch.cat(outputs_list, 0)
            outputs = outputs.detach().cpu().numpy()
            # Get train rocauc value
            train_rocaucs = []
            for i in range(props.shape[1]):
                mask = props[:, i] != 2
                train_rocauc = roc_auc_score(props[mask, i], outputs[mask, i])
                train_rocaucs.append(train_rocauc)
            model.eval()
            with torch.no_grad():
                ds = val_loader.dataset
                walks_int = ds.int_feat_tensor
                walks_float = ds.float_feat_tensor
                props = ds.prop_tensor
                walks_int = walks_int.cuda()
                walks_int = walks_int.long()
                walks_float = walks_float.cuda()
                walks_float = walks_float.float()
                props = props.cuda()
                outputs = model(walks_int, walks_float)
                loss = loss_function(props, outputs)
                props = props.cpu().numpy()
                outputs = outputs.cpu().numpy()
                val_rocaucs = []
                for i in range(props.shape[1]):
                    mask = props[:, i] != 2
                    val_rocauc = roc_auc_score(props[mask, i], outputs[mask,
                                                                       i])
                    val_rocaucs.append(val_rocauc)
            print("load: " + str(file_num) + ", epochs: " + str(epoch))
            print("training loss")
            # Slightly approximate since last batch can be smaller...
            tl = statistics.mean(train_loss_list)
            print(tl)
            print("val loss")
            vl = loss.item()
            print(vl)
            print("train mean roc auc")
            tmra = sum(train_rocaucs) / len(train_rocaucs)
            print(tmra)
            print("val mean roc auc")
            vmra = sum(val_rocaucs) / len(val_rocaucs)
            print(vmra)
            print("time")
            print(time.time() - t)
            tl_list.append(tl)
            vl_list.append(vl)
            tmra_list.append(tmra)
            vmra_list.append(vmra)
            model.cpu()
        file_num += 1
        del train_loader
    save_plot(tl_list, vl_list, tmra_list, vmra_list)
    return model
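
masked_cross_entropy is not included in this snippet. Based on the masking used above (a label of 2 marks a missing value, the other labels are 0/1, and raw outputs are used as ROC AUC scores), a minimal sketch could look like the following; it is an assumption, not the original implementation:

import torch
import torch.nn.functional as F

def masked_cross_entropy(props, outputs):
    # props: (batch, num_tasks) integer labels in {0, 1, 2}, where 2 == missing;
    # outputs: (batch, num_tasks) raw logits from the model.
    mask = (props != 2).float()
    targets = props.clamp(max=1).float()  # masked 2s become dummy 1s
    loss = F.binary_cross_entropy_with_logits(outputs, targets,
                                              reduction='none')
    # Average only over the entries that actually have a label.
    return (loss * mask).sum() / mask.sum().clamp(min=1)
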
Example #6
def train(cfg_trn, cfg_vld):
    base_lr = cfg_trn['base_lr']
    batches_per_iter = cfg_trn['batches_per_iter']
    log_after = cfg_trn['log_after']
    checkpoint_after = cfg_trn['checkpoint_after']

    # val_after = cfg_vld['val_after']
    # val_labels = cfg_vld['annF']
    # val_output_name = cfg_vld['']
    # val_images_folder = cfg_vld['root']

    net = ResNet(ResNet_Spec[18])

    dataset = CocoDataset(cfg=cfg_trn)
    train_loader = DataLoader(dataset,
                              batch_size=cfg_trn['batch_size'],
                              num_workers=cfg_trn['num_workers'],
                              shuffle=True)

    optimizer = opt.Adam(net.parameters(),
                         lr=cfg_trn['base_lr'],
                         weight_decay=5e-4)

    num_iter = 0
    current_epoch = 0
    drop_after_epoch = [100, 200, 260]
    scheduler = opt.lr_scheduler.MultiStepLR(optimizer, milestones=drop_after_epoch, gamma=0.333)
    if cfg_trn['checkpoint_path']:
        checkpoint = torch.load(cfg_trn['checkpoint_path'])

        # if from_mobilenet:
        #     load_from_mobilenet(net, checkpoint)
        # else:
        #     load_state(net, checkpoint)
        #     if not weights_only:
        #         optimizer.load_state_dict(checkpoint['optimizer'])
        #         scheduler.load_state_dict(checkpoint['scheduler'])
        #         num_iter = checkpoint['iter']
        #         current_epoch = checkpoint['current_epoch']

    net = DataParallel(net).cuda()
    net.train()
    for epochId in range(current_epoch, 280):
        scheduler.step(epoch=epochId)
        total_losses = [0, 0] * (cfg_trn['num_hourglass_stages'] + 1)  # heatmaps loss, paf loss per stage
        batch_per_iter_idx = 0
        for batch_data in train_loader:
            if batch_per_iter_idx == 0:
                optimizer.zero_grad()

            images = batch_data['image'].cuda()
            keypoint_masks = batch_data['keypoint_mask'].cuda()
            paf_masks = batch_data['paf_mask'].cuda()
            keypoint_maps = batch_data['keypoint_maps'].cuda()
            paf_maps = batch_data['paf_maps'].cuda()

            stages_output = net(images)

            losses = []
            for loss_idx in range(len(total_losses) // 2):
                losses.append(l2loss(stages_output[loss_idx * 2], keypoint_maps, keypoint_masks, images.shape[0]))
                losses.append(l2loss(stages_output[loss_idx * 2 + 1], paf_maps, paf_masks, images.shape[0]))
                total_losses[loss_idx * 2] += losses[-2].item() / batches_per_iter
                total_losses[loss_idx * 2 + 1] += losses[-1].item() / batches_per_iter

            loss = losses[0]
            for loss_idx in range(1, len(losses)):
                loss += losses[loss_idx]
            loss /= batches_per_iter
            loss.backward()
            batch_per_iter_idx += 1
            if batch_per_iter_idx == batches_per_iter:
                optimizer.step()
                batch_per_iter_idx = 0
                num_iter += 1
            else:
                continue

            if num_iter % log_after == 0:
                print('Iter: {}'.format(num_iter))
                for loss_idx in range(len(total_losses) // 2):
                    print('\n'.join(['stage{}_pafs_loss:     {}', 'stage{}_heatmaps_loss: {}']).format(
                        loss_idx + 1, total_losses[loss_idx * 2 + 1] / log_after,
                        loss_idx + 1, total_losses[loss_idx * 2] / log_after))
                for loss_idx in range(len(total_losses)):
                    total_losses[loss_idx] = 0
            if num_iter % checkpoint_after == 0:
                snapshot_name = '{}/checkpoint_iter_{}.pth'.format(checkpoints_folder, num_iter)
                torch.save({'state_dict': net.module.state_dict(),
                            'optimizer': optimizer.state_dict(),
                            'scheduler': scheduler.state_dict(),
                            'iter': num_iter,
                            'current_epoch': epochId},
                           snapshot_name)
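
l2loss is referenced but not defined here. Judging from how it is called (prediction, target map, mask, batch size), a hedged sketch of a masked L2 loss normalized by batch size:

def l2loss(pred, target, mask, batch_size):
    # Assumed behaviour: squared error on the masked region, halved and
    # averaged over the batch; the real implementation may differ.
    diff = (pred - target) * mask
    return (diff * diff).sum() / 2 / batch_size
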