Example #1
# Note: this snippet relies on module-level names not shown here: torch, wandb, sys,
# plus the repo's get_episodes, NatureCNN, ImpalaCNN, ProbeTrainer, majority_baseline,
# train_encoder, train_encoder_methods and probe_only_methods.
def run_probe(args):
    device = torch.device(
        "cuda:" + str(args.cuda_id) if torch.cuda.is_available() else "cpu")
    collect_mode = args.collect_mode if args.method != 'pretrained-rl-agent' else "pretrained_representations"
    tr_eps, val_eps, tr_labels, val_labels, test_eps, test_labels = get_episodes(
        args, device, collect_mode=collect_mode, train_mode="probe")
    print("got episodes!")
    observation_shape = tr_eps[0][0].shape
    wandb.config.update(vars(args))

    if args.train_encoder and args.method in train_encoder_methods:
        print("Training encoder from scratch")
        encoder = train_encoder(args)
        encoder.probing = True
        encoder.eval()

    else:
        if args.encoder_type == "Nature":
            encoder = NatureCNN(observation_shape[0], args)
        elif args.encoder_type == "Impala":
            encoder = ImpalaCNN(observation_shape[0], args)

        if args.weights_path == "None":
            if args.method not in probe_only_methods:
                sys.stderr.write(
                    "Probing without loading in encoder weights! Are sure you want to do that??"
                )
        else:
            print(
                "Print loading in encoder weights from probe of type {} from the following path: {}"
                .format(args.method, args.weights_path))
            encoder.load_state_dict(torch.load(args.weights_path, map_location=device))
            encoder.eval()

    torch.set_num_threads(1)

    if args.method == 'majority':
        test_acc, test_f1score = majority_baseline(tr_labels, test_labels,
                                                   wandb)

    else:
        trainer = ProbeTrainer(encoder,
                               wandb,
                               epochs=args.epochs,
                               sample_label=tr_labels[0][0],
                               lr=args.probe_lr,
                               batch_size=args.batch_size,
                               device=device,
                               patience=args.patience,
                               log=False)

        trainer.train(tr_eps, val_eps, tr_labels, val_labels)
        test_acc, test_f1score = trainer.test(test_eps, test_labels)

    print(test_acc)
    print(test_f1score)
    # wandb.log expects a mapping, so test_acc and test_f1score are presumably
    # dicts of per-label metrics rather than bare scalars.
    wandb.log(test_acc)
    wandb.log(test_f1score)
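For reference, a minimal driver for this function could follow the same pattern as Example #3 below. This is only a sketch: get_argparser and the wandb_proj flag are taken from Example #3, while the "probe-only" tag is an arbitrary placeholder.

# Hypothetical entry point for run_probe, modeled on Example #3; assumes
# run_probe is defined in the same file and that the parser exposes the flags
# referenced above (probe_lr, patience, weights_path, ...).
import wandb
from src.utils import get_argparser

if __name__ == "__main__":
    parser = get_argparser()
    args = parser.parse_args()
    wandb.init(project=args.wandb_proj, tags=["probe-only"])
    run_probe(args)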
Example #2
# This variant relies on the same module-level names as Example #1 (not shown here).
def run_probe(args):
    wandb.config.update(vars(args))
    tr_eps, val_eps, tr_labels, val_labels, test_eps, test_labels = get_episodes(
        steps=args.probe_steps,
        env_name=args.env_name,
        seed=args.seed,
        num_processes=args.num_processes,
        num_frame_stack=args.num_frame_stack,
        downsample=not args.no_downsample,
        color=args.color,
        entropy_threshold=args.entropy_threshold,
        collect_mode=args.probe_collect_mode,
        train_mode="probe",
        checkpoint_index=args.checkpoint_index,
        min_episode_length=args.batch_size)

    print("got episodes!")

    if args.train_encoder and args.method in train_encoder_methods:
        print("Training encoder from scratch")
        encoder = train_encoder(args)
        encoder.probing = True
        encoder.eval()

    elif args.method in ["pretrained-rl-agent", "majority"]:
        encoder = None

    else:
        observation_shape = tr_eps[0][0].shape
        if args.encoder_type == "Nature":
            encoder = NatureCNN(observation_shape[0], args)
        elif args.encoder_type == "Impala":
            encoder = ImpalaCNN(observation_shape[0], args)

        if args.weights_path == "None":
            if args.method not in probe_only_methods:
                sys.stderr.write(
                    "Probing without loading in encoder weights! Are sure you want to do that??"
                )
        else:
            print(
                "Print loading in encoder weights from probe of type {} from the following path: {}"
                .format(args.method, args.weights_path))
            encoder.load_state_dict(torch.load(args.weights_path))
            encoder.eval()

    torch.set_num_threads(1)

    if args.method == 'majority':
        test_acc, test_f1score = majority_baseline(tr_labels, test_labels,
                                                   wandb)

    else:
        trainer = ProbeTrainer(encoder=encoder,
                               epochs=args.epochs,
                               method_name=args.method,
                               lr=args.probe_lr,
                               batch_size=args.batch_size,
                               patience=args.patience,
                               wandb=wandb,
                               fully_supervised=(args.method == "supervised"),
                               save_dir=wandb.run.dir)

        trainer.train(tr_eps, val_eps, tr_labels, val_labels)
        test_acc, test_f1score = trainer.test(test_eps, test_labels)
        # trainer = SKLearnProbeTrainer(encoder=encoder)
        # test_acc, test_f1score = trainer.train_test(tr_eps, val_eps, tr_labels, val_labels,
        #                                             test_eps, test_labels)

    print(test_acc, test_f1score)
    wandb.log(test_acc)
    wandb.log(test_f1score)
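The majority_baseline helper used above is imported from the repository and not shown in this listing. As a rough illustration only (not the actual implementation), a majority-class baseline over per-frame label dicts could be computed as below; the nested lists-of-episodes layout is inferred from tr_labels[0][0] being passed as sample_label in Example #1, and the function name is hypothetical.

# Illustrative sketch of a majority-class baseline; NOT the repository's
# majority_baseline. Assumes labels are lists of episodes, each a list of
# per-frame dicts mapping label name -> integer value.
from collections import Counter
from itertools import chain

from sklearn.metrics import f1_score


def majority_baseline_sketch(tr_labels, test_labels):
    train_frames = list(chain.from_iterable(tr_labels))
    test_frames = list(chain.from_iterable(test_labels))
    acc, f1 = {}, {}
    for key in train_frames[0]:
        # Predict the most frequent training value for every test frame.
        majority = Counter(frame[key] for frame in train_frames).most_common(1)[0][0]
        y_true = [frame[key] for frame in test_frames]
        y_pred = [majority] * len(y_true)
        acc[key] = sum(t == p for t, p in zip(y_true, y_pred)) / len(y_true)
        f1[key] = f1_score(y_true, y_pred, average="weighted")
    return acc, f1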
Example #3
from src.dim_baseline import DIMTrainer
from src.encoders import ImpalaCNN, NatureCNN, NatureOneCNN
from src.global_infonce_stdim import GlobalInfoNCESpatioTemporalTrainer
from src.global_local_infonce import GlobalLocalInfoNCESpatioTemporalTrainer
from src.infonce_spatio_temporal import InfoNCESpatioTemporalTrainer
from src.no_action_feedforward_predictor import NaFFPredictorTrainer
from src.spatio_temporal import SpatioTemporalTrainer
from src.utils import get_argparser
from src.vae import VAETrainer
import torch
import wandb

if __name__ == "__main__":
    parser = get_argparser()
    print("1")
    args = parser.parse_args()
    print("2")
    tags = ["pretraining-only"]
    print("3")
    wandb.init(project=args.wandb_proj, entity="curl-atari", tags=tags)
    print("4")
    config = {}
    print("5")
    config.update(vars(args))
    print("6")
    wandb.config.update(config)
    print("7")
    index_array = torch.randperm(823)
    print("8")
    train_encoder(args, index_array)
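Two notes on this script: train_encoder is not among the imports, so it is presumably defined elsewhere in the same file and simply not shown in this listing. Also, torch.randperm is unseeded here, so index_array differs between runs even with a fixed seed; if reproducibility matters, the permutation can be tied to the parsed seed (assuming the parser exposes args.seed, as Example #2 suggests). A small optional sketch:

# Optional, not part of the original script: seed the permutation so repeated
# runs with the same --seed produce the same index_array.
generator = torch.Generator().manual_seed(args.seed)
index_array = torch.randperm(823, generator=generator)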