2017-04-15 1 views
0

私のデータセット(50x50 RGBイメージ10,000個)を2つのデータセットに分割したいと思います。次のようなものがあります。トーチ:パーティションテンソル

X = torch.rand(10000, 3, 50, 50) 
inds = torch.randperm(X:size(1))[{ { 1, nTrain } }]:long() 
X_selected = X:index(1, inds) 
X_remaining = X:delete(1, inds) 

私はGoogle検索でも、TorchのGitHubドキュメントを取得しています。これどうやってするの?

答えて

1

あなたは

X = torch.rand(10000, 3, 50, 50) 
inds = torch.randperm(X:size(1)):long() 
train_inds = inds:narrow(1, 1, nTrain) 
valid_inds = inds:narrow(1, nTrain + 1, X:size(1) - nTrain) 
X_train = X:index(1, train_inds) 
X_valid = X:index(1, valid_inds) 
この方法を試すことができます
関連する問題